From 5ddcb306c336b9380d00581e549248406ee3f8ea Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 16 Oct 2024 22:03:00 -0700 Subject: [PATCH 01/34] WIP: router chunked prefill --- proto/generate.proto | 23 +++++- router/client/src/client.rs | 10 ++- router/client/src/sharded_client.rs | 3 +- router/src/batch.rs | 14 +++- router/src/health.rs | 5 +- router/src/infer.rs | 119 +++++++++++++++++++++------- 6 files changed, 134 insertions(+), 40 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 19dc70076..9ba083efc 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -57,6 +57,7 @@ message InfoResponse { bool supports_generation = 8; bool supports_embeddings = 9; bool supports_classification = 10; + bool supports_chunking = 11; } /// Empty request @@ -156,8 +157,12 @@ message Request { repeated uint32 blocks = 9; /// Paged attention slots repeated uint32 slots = 10; - /// Prefix length that can be retrieved from the KV cache - uint32 prefix_len = 11; + /// Tokens that can be retrieved from the KV cache. + /// This value is set for the first prefill and never reset + uint32 cache_len = 12; + /// Chunk of tokens that must be computed for the first prefill + /// This value is set for the first prefill and never reset + optional uint32 chunk_len = 13; } message Batch { @@ -182,6 +187,8 @@ message CachedBatch { uint32 size = 3; /// Maximum number of tokens this batch will grow to uint32 max_tokens = 4; + /// Number of tokens in the next forward + uint32 current_tokens = 5; } enum FinishReason { @@ -261,6 +268,8 @@ message FilterBatchResponse { message PrefillRequest { /// Batch Batch batch = 1; + /// Optional cached batch + CachedBatch cached_batch = 2; } message PrefillResponse { @@ -268,6 +277,16 @@ message PrefillResponse { repeated Generation generations = 1; /// Next batch (cached) optional CachedBatch batch = 2; + + // TODO(travis): add timings + // /// Forward elapsed time in nanoseconds + // uint64 forward_ns = 3; + // /// Decode elapsed time in nanoseconds + // uint64 decode_ns = 4; + // /// Total elapsed time in nanoseconds + // uint64 total_ns = 5; + // /// Concatenate elapsed time in nanoseconds + // optional uint64 concat_ns = 6; } message DecodeRequest { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index f6e663cc7..7007a26aa 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -143,7 +143,8 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], - prefix_len: 0, + cache_len: 0, + chunk_len: None, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -195,8 +196,13 @@ impl Client { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); let response = self.stub.prefill(request).await?.into_inner(); Ok((response.generations, response.batch)) } diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index b0cb08b28..92651432f 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -124,11 +124,12 @@ impl ShardedClient { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option)> { let 
futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone()))) .collect(); let results: Result, Option)>> = join_all(futures).await.into_iter().collect(); diff --git a/router/src/batch.rs b/router/src/batch.rs index de3032f8d..e908c899c 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -274,6 +274,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug { &mut self, client: &mut ShardedClient, batch: Batch, + cached_batch: Option, span: Span, generation_health: &Arc, ) -> Option; @@ -341,7 +342,8 @@ impl BatchEntries for GenerateBatchEntries { adapter_index: adapter.index(), blocks, slots, - prefix_len, + cache_len: prefix_len, + chunk_len: None, }; self.state.add(id, entry, adapter, request_proto); @@ -385,12 +387,14 @@ impl BatchEntries for GenerateBatchEntries { &mut self, client: &mut ShardedClient, batch: Batch, + cached_batch: Option, span: Span, generation_health: &Arc, ) -> Option { prefill( client, batch, + cached_batch, &mut self.state.batch_entries, &generation_health, ) @@ -470,7 +474,8 @@ impl BatchEntries for EmbedBatchEntries { adapter_index: adapter.index(), blocks, slots, - prefix_len, + cache_len: prefix_len, + chunk_len: None, }; self.state.add(id, entry, adapter, request_proto); @@ -514,6 +519,7 @@ impl BatchEntries for EmbedBatchEntries { &mut self, client: &mut ShardedClient, batch: Batch, + cached_batch: Option, span: Span, generation_health: &Arc, ) -> Option { @@ -593,7 +599,8 @@ impl BatchEntries for ClassifyBatchEntries { adapter_index: adapter.index(), blocks, slots, - prefix_len, + cache_len: prefix_len, + chunk_len: None, }; self.state.add(id, entry, adapter, request_proto); @@ -637,6 +644,7 @@ impl BatchEntries for ClassifyBatchEntries { &mut self, client: &mut ShardedClient, batch: Batch, + cached_batch: Option, span: Span, generation_health: &Arc, ) -> Option { diff --git a/router/src/health.rs b/router/src/health.rs index 02aea0509..1c6aaed2d 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -73,7 +73,8 @@ impl Health { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), - prefix_len: 0, + cache_len: 0, + chunk_len: None, }; let batch = Batch { id: BATCH_ID, @@ -83,7 +84,7 @@ impl Health { max_blocks: 1, }; // Skips the queue - let value = self.client.prefill(batch).await.is_ok(); + let value = self.client.prefill(batch, None).await.is_ok(); // Update generation health self.generation_health.store(value, Ordering::SeqCst); value diff --git a/router/src/infer.rs b/router/src/infer.rs index 391dd1271..11eeb374f 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -899,6 +899,8 @@ async fn batching_task( adapter_scheduler: AdapterScheduler, eager_prefill: bool, ) { + let support_chunking = true; + // Infinite loop loop { // Fire if a new request comes in or an adapter becomes ready @@ -917,7 +919,7 @@ async fn batching_task( .await { let mut cached_batch = batch_entries - .process_first(&mut client, batch, span, &generation_health) + .process_first(&mut client, batch, None, span, &generation_health) .await; let mut waiting_tokens = 1; @@ -927,6 +929,7 @@ async fn batching_task( // Get current batch info let mut batch_size = batch.size; let batch_max_tokens = batch.max_tokens; + let current_tokens = batch.current_tokens; let mut batches = vec![batch]; metrics::gauge!("lorax_batch_current_size", batch_size as f64); metrics::gauge!("lorax_batch_current_max_tokens", batch_max_tokens 
as f64); @@ -935,24 +938,53 @@ async fn batching_task( // TODO(travis): can execute this more efficiently by making it event-driven adapter_scheduler.remove_errored_adapters().await; - let min_size = if waiting_tokens >= max_waiting_tokens || eager_prefill { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None + let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let (min_size, max_size, prefill_token_budget) = if support_chunking { + // Since the next batch will be concatenated with the current batch, + // the current batch tokens must be subtracted from the prefill budget + let prefill_token_budget = + max_batch_prefill_tokens.saturating_sub(current_tokens); + // We can ignore min_size and max_size + // Models that rely on max_size cannot support chunking + // Regarding min_size, chunking allows us to consistently run at the compute + // bound, making min_size useless. + (None, None, prefill_token_budget) } else { - // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let max_batch_size: Option = None; // TODO(travis) + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + (min_size, max_size, max_batch_prefill_tokens) }; - let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + // let min_size = if waiting_tokens >= max_waiting_tokens || eager_prefill { + // // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // // to add a new batch even though its size might be small + // None + // } else { + // // Minimum batch size + // Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + // }; + let mut adapters_in_use = batch_entries.adapters_in_use(); // Try to get a new batch - while let Some((mut new_entries, new_batch, span)) = adapter_scheduler + while let Some((new_entries, new_batch, span)) = adapter_scheduler .next_batch( adapters_in_use.clone(), min_size, - max_batch_prefill_tokens, + prefill_token_budget, token_budget, ) .await @@ -967,37 +999,57 @@ async fn batching_task( if min_size.is_some() { metrics::increment_counter!("lorax_batch_concat", "reason" => "backpressure"); } else { - metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded"); + if support_chunking { + metrics::increment_counter!("lorax_batch_concat", "reason" => "chunking") + } else { + metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded") + }; } - batch_entries - .mut_state() - .batch_entries - .iter_mut() - .for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); + let cached_batch = if support_chunking { + // Concat current batch to the new one + batches.pop() + } else {
+ // Request are waiting only if we don't support chunking + batch_entries.mut_state().batch_entries.iter_mut().for_each( + |(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }, + ); + None + }; + + let new_adapters_in_use = new_entries.adapters_in_use(); + batch_entries.extend(new_entries); // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = new_entries - .process_first(&mut client, new_batch, span, &generation_health) + let new_cached_batch = batch_entries + .process_first( + &mut client, + new_batch, + cached_batch, + span, + &generation_health, + ) .await; - adapters_in_use.extend(new_entries.adapters_in_use()); + adapters_in_use.extend(new_adapters_in_use); // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { - batch_entries.extend(new_entries); batches.push(new_cached_batch); + } else if support_chunking { + // New cached batch is empty, no work left + break; } if !eager_prefill { @@ -1040,6 +1092,7 @@ async fn batching_task( pub(crate) async fn prefill( client: &mut ShardedClient, batch: Batch, + cached_batch: Option, entries: &mut IntMap, generation_health: &Arc, ) -> Option { @@ -1047,7 +1100,7 @@ pub(crate) async fn prefill( let batch_id = batch.id; metrics::increment_counter!("lorax_batch_inference_count", "method" => "prefill"); - match client.prefill(batch).await { + match client.prefill(batch, cached_batch).await { Ok((generations, next_batch)) => { // Update health generation_health.store(true, Ordering::SeqCst); @@ -1057,6 +1110,12 @@ pub(crate) async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; + // TODO(travis) + // if let Some(concat_duration) = timings.concat { + // metrics::histogram!("lorax_batch_concat_duration", "method" => "decode") + // .record(concat_duration.as_secs_f64()); + // } + metrics::histogram!("lorax_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); metrics::increment_counter!("lorax_batch_inference_success", "method" => "prefill"); next_batch From 1d76aa80730f4ed0f05b78ad9b3d1a57abaa9c3c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 16 Oct 2024 22:21:34 -0700 Subject: [PATCH 02/34] WIP: Seqlen --- .../custom_modeling/flash_qwen2_modeling.py | 26 ++++++++++--------- server/lorax_server/utils/paged_attention.py | 4 ++- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index abbda50da..805c4576d 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -6,6 +6,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + # Flash attention imports import dropout_layer_norm import torch @@ -219,7 +221,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -270,7 +272,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, 
block_tables, - input_lengths, + seqlen, max_s, ) @@ -358,7 +360,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -374,7 +376,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -423,7 +425,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], adapter_data: AdapterBatchData, @@ -444,7 +446,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -482,7 +484,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -495,7 +497,7 @@ def forward( # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + seqlen = seqlen.clamp(max=self.max_past) hidden_states = self.model( input_ids, @@ -504,7 +506,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -537,7 +539,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -550,7 +552,7 @@ def forward( # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + seqlen = seqlen.clamp(max=self.max_past) hidden_states = self.model( input_ids, @@ -559,7 +561,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 34d11e843..60bee0e43 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -1,5 +1,6 @@ from typing import Optional +from lorax_server.utils.attention.common import Seqlen import torch from lorax_server.utils.import_utils import SYSTEM @@ -54,7 +55,7 @@ def attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, softcap: Optional[float] = None, ): @@ -86,6 +87,7 @@ def attention( # # value_cache => [num_blocks, num_heads, head_size, block_size] + input_lengths = seqlen.input_lengths + seqlen.cache_lengths block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE From 77b0ce67f1c311a3ed1f98da32cdf3e71510c846 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 12:51:13 -0700 Subject: [PATCH 03/34] WIP: server --- server/lorax_server/models/flash_causal_lm.py | 794 +++++++++++------- server/lorax_server/utils/state.py | 25 + 2 files changed, 532 insertions(+), 287 deletions(-) diff --git 
a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2626d30b8..c9586a013 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -32,7 +32,7 @@ from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_speculative_tokens, warmup_mode +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_speculative_tokens, get_supports_chunking, warmup_mode from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import Weights @@ -54,53 +54,64 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor + # Can be a list for easy filtering + # If `input_ids` is a list, it needs to be materialized to a tensor first + input_ids: Union[torch.Tensor, List[List[int]]] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] # Spculative decoding values speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - - # Paged Attention values - - # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slots: Optional[torch.Tensor] - # size [b], containing the number of blocks that can be retrieved from the cache - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor + max_input_length: int + max_current_length: int - max_seqlen: int + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each 
prefill forward prefill_cu_outlens: Optional[List[int]] - - # Prefixes - prefix_ids: List[List[int]] + # Will be set by `generate_token` and reset after each prefill forward + prefill_logprob_tokens: List[Optional[NextTokens]] # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor - # Lengths of all generations present in the batch + # Lengths of all generations present in the batch input_lengths: List[int] - input_lengths_tensor: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + cache_lengths: List[int] + prompt_lengths: List[int] + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + input_lengths_tensor: Optional[torch.Tensor] + cache_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: torch.Tensor + prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -109,6 +120,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias: List[StoppingCriteria] # Adapter metadata for each request + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode adapter_meta: AdapterBatchMetadata # Number of blocks in this batch @@ -116,17 +128,19 @@ class FlashCausalLMBatch(Batch): # Maximum number of blocks max_blocks: int - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] = None - def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) + @classmethod def to_pb_embed(self, batch, embeddings) -> generate_pb2.EmbedResponse: @@ -164,45 +178,28 @@ def from_pb( batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)[ "input_ids" ] + + speculative_tokens = get_speculative_tokens() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - + cache_lengths = [] input_lengths = [] + prompt_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] - prefix_ids = [] + all_postfix_ids = [] requests_idx_mapping = {} - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - next_token_chooser_parameters = [] stopping_criterias = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 max_length = 0 max_blocks = 0 block_tables = [] - slots = [] - prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): @@ -211,107 +208,87 @@ def from_pb( tokenized_input = tokenized_input[-r.truncate :] - orig_input_length = len(tokenized_input) - if PREFIX_CACHING: - prefix_len = r.prefix_len - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 + prompt_length = len(tokenized_input) + prompt_lengths.append(prompt_length) + + cache_length = r.cache_len + assert ( + cache_length <= prompt_length + ), f"Prefix {cache_length} vs input 
{prompt_length}" + if cache_length == prompt_length: + assert False, "unreachable" + + # TODO(travis): double-check prefix caching + # if PREFIX_CACHING: + # prefix_len = r.prefix_len + # if prefix_len == orig_input_length: + # assert prefix_len > 0 + # prefix_len -= 1 + # else: + # prefix_len = 0 + + # `chunk_len` is an optional field in the protobuf + # It is only set if the model support chunking + if r.HasField("chunk_len"): + input_length = r.chunk_len + + if cache_length + input_length < prompt_length: + # FIXME: speculate is not supported for context chunking at the moment + assert speculative_tokens == 0 + assert get_supports_chunking() + assert input_length > 0 + + postfix_ids = tokenized_input[ + cache_length : cache_length + input_length + ] + assert ( + len(postfix_ids) == input_length + ), "Rust and Python tokenizers are not aligned" else: - prefix_len = 0 + # Use all the remaining ids + postfix_ids = tokenized_input[cache_length:] + input_length = len(postfix_ids) - prefix_ids.append(tokenized_input[:prefix_len]) - tokenized_input = tokenized_input[prefix_len:] - - input_length = len(tokenized_input) input_lengths.append(input_length) - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) + prefix_offsets.append(prompt_length - 5) + read_offsets.append(prompt_length) + all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange(prefix_len, orig_input_length, dtype=torch.int32) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) - adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) - adapter_set.add(r.adapter_index) - - speculative_tokens = get_speculative_tokens() + # adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) + # adapter_set.add(r.adapter_index) # Tokens that need to be mapped to blocks. # Remove one as the first token des not have a past - block_tokens = orig_input_length + max_new_tokens - 1 + speculative_tokens - - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present). 
- slot_tokens = input_length + max_new_tokens - 1 + speculative_tokens + block_tokens = prompt_length + max_new_tokens - 1 + speculative_tokens # blocks and slots can be empty (for example in warmup) if not r.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)] - request_slots = [s for b in request_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_len: #: orig_input_length + max_new_tokens + speculative_length - ] block_tables.append(request_blocks) - slots.extend(request_slots) - prefix_lens.append(prefix_len) + cache_lengths.append(cache_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if SLIDING_WINDOW is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - SLIDING_WINDOW), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append(torch.tensor([cumulative_length + input_length - 1], dtype=torch.int32)) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += input_length - cumulative_slot_tokens += slot_tokens - max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) - max_length = max(max_length, input_length + max_new_tokens + speculative_tokens) - - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) + max_input_length = max(max_input_length, input_length) + max_current_length = max(max_current_length, cache_length + input_length) + max_length = max( + max_length, + prompt_length + max_new_tokens + speculative_tokens, + ) # always use the base model tokenizer for the next token chooser until we revisit adding back support # for per-request tokenizers @@ -319,7 +296,6 @@ def from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, request_tokenizers, dtype, device ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros((len(all_input_ids), max_length), dtype=np.int64) @@ -329,85 +305,50 @@ def from_pb( # Create tensors on device all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device) - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if SLIDING_WINDOW is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] 
- slot_indices = slot_indices[0] - if SLIDING_WINDOW is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, device=device, dtype=torch.int32) - - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - if SLIDING_WINDOW is not None: - prefill_cache_indices = prefill_cache_indices.to(device) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32, device=device) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor(torch.cat(prefill_head_indices), dtype=torch.int64, device=device) - prefill_next_token_indices = torch.tensor(prefill_next_token_indices, dtype=torch.int64, device=device) - - slots = torch.tensor(slots, dtype=torch.int64, device=device) block_tables_tensor = torch.zeros((len(block_tables), max_blocks), dtype=torch.int32, device="cpu") for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) - prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + prompt_lengths_tensor = torch.tensor( + prompt_lengths, dtype=torch.int32, device=device + ) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - speculative_ids=None, - cu_seqlen_prefill=cu_seqlen_prefill, - start_slots=start_slots, - slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + cache_lengths=cache_lengths, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), + prefill_logprob_tokens=[None] * len(pb.requests), input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, + prompt_lengths=prompt_lengths, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), - prefill_cache_indices=prefill_cache_indices if SLIDING_WINDOW is not None else None, + speculative_ids=None, + prompt_lengths_tensor=prompt_lengths_tensor, + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + slot_indices=None, + slots=None, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + cache_lengths_tensor=None, + input_lengths_tensor=None, + 
adapter_meta=None, ) @classmethod @@ -431,7 +372,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -444,19 +385,23 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) - max_seqlen = 0 + max_input_length = 0 + max_current_length = 0 requests = [] - start_slots = [] block_tables = [] all_input_ids = [] - prefix_ids = [] + input_ids = [] + prompt_lengths = [] input_lengths = [] - prefix_lens = [] + cache_lengths = [] prefix_offsets = [] read_offsets = [] + prefilling_mask = [] + prefill_logprob_tokens = [] + stopping_criterias = [] adapter_set = set() @@ -472,62 +417,97 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling_mask.append(request_prefilling) + # Get length request_input_length = self.input_lengths[idx] - prefix_len = self.prefix_lens[idx] - max_seqlen = max(max_seqlen, request_input_length) + request_cache_length = self.cache_lengths[idx] + max_input_length = max(max_input_length, request_input_length) + max_current_length = max( + max_current_length, request_cache_length + request_input_length + ) all_input_ids.append(self.all_input_ids[idx]) - prefix_ids.append(self.prefix_ids[idx]) + prompt_lengths.append(self.prompt_lengths[idx]) input_lengths.append(request_input_length) - prefix_lens.append(prefix_len) + cache_lengths.append(request_cache_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - adapter_set.add(self.requests[idx].adapter_index) + prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx]) - remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + adapter_set.add(self.requests[idx].adapter_index) request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + # Input ids if the request was part of a prefilling batch + # If the batch was decoding we can index into the tensor directly later + if self.prefilling: + input_ids.append(self.input_ids[idx]) + else: + # Copy to tensor (CPU) + slot_indices[i] = cumulative_max_length + + remaining_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) - # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] + request_input_length + remaining_tokens - 1 - ] = True + # Set slice + slot_filtering_indices[ + self.slot_indices[idx] : self.slot_indices[idx] + + request_input_length + + remaining_tokens + - 1 + ] = True - cumulative_max_length += request_input_length + remaining_tokens - 1 + cumulative_max_length += request_input_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) - # Index into tensors - input_ids = self.input_ids[indices] - position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = 
self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] - input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] - prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None - - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) + prompt_lengths_tensor = self.prompt_lengths_tensor[indices] + + if self.prefilling: + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slot_indices = None + slots = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + else: + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] + input_lengths_tensor = self.input_lengths_tensor[indices] + slots = self.slots[slot_filtering_indices] + cache_lengths_tensor = self.cache_lengths_tensor[indices] + + # Move to GPU now that we have the whole tensor + slot_indices = slot_indices.to(device) + + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return type(self)( batch_id=self.batch_id, @@ -537,34 +517,33 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids=position_ids, speculative_ids=speculative_ids, cu_seqlen_prefill=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - max_seqlen=max_seqlen, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=self.prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, ) @classmethod @@ -574,59 +553,88 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch requests = [] requests_idx_mapping = {} + prefilling = False num_blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 max_length = 0 - max_seqlen = 0 + max_input_length 
= 0 + max_current_length = 0 for b in batches: total_batch_size += len(b) - total_slots += len(b.slots) - num_blocks += b.num_blocks max_blocks = max(max_blocks, b.max_blocks) - max_seqlen = max(max_seqlen, b.max_seqlen) + # If `b` is prefilling and was just filtered, `b.slots` is None + # `total_slots` is not used if any of the batches is prefilling + total_slots += len(b.slots) if not b.prefilling else 0 + num_blocks += b.num_blocks + max_input_length = max(max_input_length, b.max_input_length) + max_current_length = max(max_current_length, b.max_current_length) speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 max_length = max( max_length, max( - input_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - stopping_criteria.current_tokens - for input_length, stopping_criteria in zip(b.input_lengths, b.stopping_criterias) + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias + ) ), ) + prefilling = prefilling or b.prefilling + + if prefilling: + input_ids = [] + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids = None + slots = None + slot_indices = None + cache_lengths_tensor = None + input_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) + position_ids = batches[0].position_ids.new_empty(total_batch_size) + slots = batches[0].slots.new_empty(total_slots) + slot_indices = batches[0].slot_indices.new_empty(total_batch_size) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( + total_batch_size + ) + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( + total_batch_size + ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() - input_ids = batches[0].input_ids.new_empty(total_batch_size) - position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) - slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) + prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(total_batch_size) block_tables_tensor = batches[0].block_tables_tensor.new_zeros((total_batch_size, max_blocks)) - prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros((total_batch_size, max_length)) - total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) - - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) - adapter_set = set() - adapter_segment_builder = SegmentConcatBuilder() - - start_slots = [] block_tables = [] - prefix_lens = [] + cache_lengths = [] all_input_ids = [] - prefix_ids = [] + prompt_lengths = [] input_lengths = [] prefix_offsets = [] read_offsets = [] + prefill_logprob_tokens = [] + next_token_chooser_parameters = [] sequence_processors = [] stopping_criterias = [] + prefilling_mask = [] # Cumulative length cumulative_batch_size = 0 @@ -645,26 +653,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index 
= cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) - input_ids[start_index:end_index] = batch.input_ids - position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor - slots[slots_start_index:slots_end_index] = batch.slots - - # Copy over adapter indices - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] - adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - - # Update adapter segments - adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) - all_input_ids_tensor[start_index:end_index, : batch.all_input_ids_tensor.shape[1]] = ( batch.all_input_ids_tensor[:, :max_length] ) @@ -673,19 +663,56 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch batch.block_tables_tensor[:, :max_blocks] ) - prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - start_slots.append(batch.start_slots + cumulative_slots) + if not prefilling: + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + input_ids[start_index:end_index] = batch.input_ids + position_ids[start_index:end_index] = batch.position_ids + slots[slots_start_index:slots_end_index] = batch.slots + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor + + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + + # Update + cumulative_slots += len(batch.slots) + else: + if isinstance(batch.input_ids, torch.Tensor): + batch.input_ids = batch.input_ids.view(-1, 1).tolist() + input_ids.extend(batch.input_ids) + + prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) - prefix_lens.extend(batch.prefix_lens) + cache_lengths.extend(batch.cache_lengths) all_input_ids.extend(batch.all_input_ids) - prefix_ids.extend(batch.prefix_ids) + prompt_lengths.extend(batch.prompt_lengths) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) + prefill_logprob_tokens.extend(batch.prefill_logprob_tokens) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) if batch.next_token_chooser.schema_processor is not None: sequence_processors.extend(batch.next_token_chooser.schema_processor.sequence_processors) @@ -696,9 +723,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Update cumulative_batch_size += len(batch) - cumulative_slots 
+= len(batch.slots) - - start_slots = torch.concat(start_slots) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -712,7 +736,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None ) - adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + if adapter_segment_builder is not None: + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) return cls( batch_id=batches[0].batch_id, @@ -722,34 +753,221 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids=position_ids, speculative_ids=speculative_ids, cu_seqlen_prefill=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, slots=slots, - max_seqlen=max_seqlen, + max_input_length=max_input_length, + max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, + prefill_logprob_tokens=prefill_logprob_tokens, + prompt_lengths=prompt_lengths, + prompt_lengths_tensor=prompt_lengths_tensor, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), + adapter_meta=adapter_meta, + ) + + def prepare_for_prefill(self): + # Prepare values if we need to continue prefilling + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert self.speculative_ids is None + + sliding_window = get_sliding_windows() + position_ids = [] + cu_seqlen_prefill = [0] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + slots = [] + adapter_indices_list = [] + adapter_set = set() + + for i, ( + r, + cache_length, + input_length, + prompt_length, + request_prefilling, + blocks, + ) in enumerate( + zip( + self.requests, + self.cache_lengths, + self.input_lengths, + self.prompt_lengths, + self.prefilling_mask, + self.block_tables, + ) + ): + next_chunk_length = input_length + # Position ids + request_position_ids = torch.arange( + cache_length, cache_length + input_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + input_length) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * 
BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots + + request_slots = request_slots[cache_length:] + request_slot_indices = torch.arange( + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + + # Create tensor to slice into the kv tensor in prefill + if sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - sliding_window), + cumulative_length + input_length, + dtype=torch.int64, + ) + + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + slots.extend(request_slots) + slot_indices.append(request_slot_indices) + + if sliding_window is not None: + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index)) + adapter_set.add(adapter_index) + + # Update + cumulative_length += next_chunk_length + cumulative_slot_tokens += len(request_slots) + + device = self.block_tables_tensor.device + + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + if len(self) > 1: + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + position_ids = position_ids[0] + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + self.prefill_cu_outlens = prefill_cu_outlens + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + self.cu_seqlen_prefill = cu_seqlen_prefill + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.cat(prefill_head_indices).to(device) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + 
self.prefill_next_token_indices = prefill_next_token_indices + self.slots = torch.tensor(slots, dtype=torch.int64, device=device) + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device + ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, ) def __len__(self): @@ -777,6 +995,7 @@ def __init__( embedding_dim: Optional[int] = None, trust_remote_code: bool = False, processor=None, + supports_chunking: bool = True, ): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -899,6 +1118,7 @@ def __init__( dynamic_adapter_loading_enabled=not merge_adapter_weights, trust_remote_code=trust_remote_code, processor=processor, + supports_chunking=supports_chunking, ) if sliding_window is not None: diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 63e73f55f..809c756a2 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -1,5 +1,6 @@ import os from contextlib import contextmanager +from typing import Optional from loguru import logger @@ -17,6 +18,10 @@ logger.info("Using flashinfer") +SUPPORTS_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + BLOCK_SIZE: int if FLASH_INFER: BLOCK_SIZE = 1 @@ -49,3 +54,23 @@ def set_speculative_tokens(value: int): def get_speculative_tokens() -> int: return SPECULATIVE_TOKENS + + +def set_supports_chunking(supports_chunking: bool): + global SUPPORTS_CHUNKING + SUPPORTS_CHUNKING = supports_chunking + + +def get_supports_chunking() -> bool: + global SUPPORTS_CHUNKING + return SUPPORTS_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global MAX_PREFILL_TOKENS + return MAX_PREFILL_TOKENS From b02a01dab0809568b3bd49a6d66f7158bdfb4098 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 13:16:28 -0700 Subject: [PATCH 04/34] Graph update --- server/lorax_server/models/flash_causal_lm.py | 47 +++++++---- server/lorax_server/utils/attention/utils.py | 14 ++-- server/lorax_server/utils/graph.py | 83 +++++++++++-------- 3 files changed, 88 insertions(+), 56 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index c9586a013..2cf976968 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union +from lorax_server.utils.attention.common import Seqlen import numpy as np import torch import torch.distributed @@ -1330,8 +1331,8 @@ def _forward_context( cu_seqlen_prefill: Optional[torch.Tensor], input_lengths: List[int], input_lengths_tensor: torch.Tensor, - prefix_lens: List[int], - prefix_lens_tensor: torch.Tensor, + cache_lens: List[int], + cache_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if not FLASH_INFER: @@ -1387,11 +1388,12 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> input_ids = 
batch.input_ids position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length if batch.speculative_ids is not None: speculative_ids = batch.speculative_ids @@ -1404,23 +1406,36 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - prefix_lens_tensor = (batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) + cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() max_s = max_s + speculative_length input_ids = new_input_ids position_ids = new_position_ids - + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=batch.max_input_length, + max_k=batch.max_current_length, + ) + # Model Forward if not use_graph: # eager mode - input_lengths = input_lengths + prefix_lens_tensor if FLASH_INFER: block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( @@ -1428,8 +1443,8 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> cu_seqlen_prefill=batch.cu_seqlen_prefill, input_lengths=batch.input_lengths, input_lengths_tensor=input_lengths, - prefix_lens=batch.prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lens=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): out = model.forward( input_ids=input_ids, @@ -1438,7 +1453,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, @@ -1453,9 +1468,9 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, - prefix_lens=batch.prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + seqlen=seqlen, + cache_lens=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, @@ -1471,7 +1486,9 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> def generate_token( self, batch: FlashCausalLMBatch, is_warmup: bool = False ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() prefill_logprobs = 
batch.prefill_next_token_indices is not None return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests) diff --git a/server/lorax_server/utils/attention/utils.py b/server/lorax_server/utils/attention/utils.py index d2c767246..80ffcbda2 100644 --- a/server/lorax_server/utils/attention/utils.py +++ b/server/lorax_server/utils/attention/utils.py @@ -4,17 +4,19 @@ def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] + *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] ) -> torch.Tensor: """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(prefix_lens) + assert len(input_lengths) == len(cache_lengths) - total_len = sum(input_lengths) + sum(prefix_lens) - block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) + total_len = sum(input_lengths) + sum(cache_lengths) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) offset = 0 - for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): - seq_len = prefix_len + input_length + for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): + seq_len = cache_length + input_length block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] offset += seq_len diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 4bd8c31cf..c16515d2e 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -7,6 +7,7 @@ from statistics import median from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple +from lorax_server.utils.attention.common import Seqlen import numpy as np import torch from loguru import logger @@ -82,9 +83,9 @@ class GraphState: position_ids: torch.Tensor block_tables: torch.Tensor slots: torch.Tensor - input_lengths: torch.Tensor - prefix_lens: List[int] - prefix_lens_tensor: torch.Tensor + seqlen: Seqlen + cache_lens: List[int] + cache_lens_tensor: torch.Tensor adapter_data: AdapterBatchData traced_adapter_layer_names: Set[str] state: Any = None @@ -109,8 +110,8 @@ def get_max_graph_state( position_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device) slots = torch.full((MAX_BATCH_SIZE,), SLOT_PAD_VALUE, dtype=torch.int64, device=device) input_lengths = torch.full((MAX_BATCH_SIZE,), max_total_tokens, dtype=torch.int32, device=device) - prefix_lens = [0] * MAX_BATCH_SIZE - prefix_lens_tensor = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device) + cache_lengths = [0] * MAX_BATCH_SIZE + cache_lengths_tensor = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device) adapter_weight_data = {} for layer_name in adapter_layers: @@ -140,9 +141,14 @@ def get_max_graph_state( position_ids=position_ids, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + seqlen=Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_total_tokens, + ), + cache_lengths_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), @@ -235,8 +241,8 @@ def trace( block_tables = max_input_state.block_tables[:batch_size] input_lengths = max_input_state.input_lengths[:batch_size] - 
prefix_lengths = max_input_state.prefix_lens[:batch_size] - prefix_lengths_tensor = max_input_state.prefix_lens_tensor[:batch_size] + cache_lengths = max_input_state.cache_lens[:batch_size] + cache_lengths_tensor = max_input_state.cache_lens_tensor[:batch_size] state = None if FLASH_INFER: @@ -247,7 +253,7 @@ def trace( block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=input_lengths.tolist(), - prefix_lens=prefix_lengths, + cache_lens=cache_lengths, ) block_tables_ptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) @@ -266,9 +272,15 @@ def trace( position_ids=max_input_state.position_ids[:batch_size], block_tables=block_tables, slots=max_input_state.slots[:batch_size], - input_lengths=input_lengths, - prefix_lens=prefix_lengths, - prefix_lens_tensor=prefix_lengths_tensor, + seqlen=Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_total_tokens, + ), + cache_lens=cache_lengths, + cache_lens_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], @@ -289,9 +301,9 @@ def trace( block_tables=input_state.block_tables, cu_seqlen_prefill=None, input_lengths=input_lengths, - input_lengths_tensor=input_state.input_lengths, - prefix_lens=prefix_lengths, - prefix_lens_tensor=prefix_lengths_tensor, + input_lengths_tensor=input_state.seqlen.input_lengths, + cache_lens=cache_lengths, + cache_lens_tensor=cache_lengths_tensor, state=input_state.state, ): # warmup @@ -302,7 +314,7 @@ def trace( kv_cache=kv_cache, block_tables=input_state.block_tables, slots=input_state.slots, - input_lengths=input_state.input_lengths, + seqlen=input_state.seqlen, max_s=max_total_tokens, adapter_data=input_state.adapter_data, prefill_cache_indices=None, @@ -319,7 +331,7 @@ def trace( kv_cache=kv_cache, block_tables=input_state.block_tables, slots=input_state.slots, - input_lengths=input_state.input_lengths, + seqlen=input_state.seqlen, max_s=max_total_tokens, adapter_data=input_state.adapter_data, prefill_cache_indices=None, @@ -338,9 +350,9 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, - prefix_lens: List[int], - prefix_lens_tensor: torch.Tensor, + seqlen: Seqlen, + cache_lens: List[int], + cache_lens_tensor: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, lm_head_indices: Optional[torch.Tensor] = None, @@ -348,15 +360,16 @@ def forward( pad_and_fill(self.input_state.input_ids, input_ids, 0) pad_and_fill(self.input_state.position_ids, position_ids, 0) pad_and_fill(self.input_state.slots, slots, SLOT_PAD_VALUE) - pad_and_fill(self.input_state.input_lengths, input_lengths + prefix_lens_tensor, 0) - self.input_state.prefix_lens[: len(prefix_lens)] = prefix_lens - pad_and_fill(self.input_state.prefix_lens_tensor, prefix_lens_tensor, 0) + pad_and_fill(self.input_state.seqlen.input_lengths, seqlen.input_lengths, 0) + pad_and_fill(self.input_state.seqlen.cache_lengths, seqlen.cache_lengths, 0) + self.input_state.cache_lens[: len(cache_lens)] = cache_lens + pad_and_fill(self.input_state.cache_lens_tensor, cache_lens_tensor, 0) if FLASH_INFER: block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=input_lengths, - prefix_lens=prefix_lens, + input_lengths=seqlen.input_lengths, + cache_lens=seqlen.cache_lengths, ) self.input_state.block_tables[: block_tables.shape[0]] = 
block_tables else: @@ -390,8 +403,8 @@ def forward( cu_seqlen_prefill=None, input_lengths=input_lengths, input_lengths_tensor=self.input_state.input_lengths, - prefix_lens=self.input_state.prefix_lens, - prefix_lens_tensor=self.input_state.prefix_lens_tensor, + cache_lens=self.input_state.cache_lens, + cache_lens_tensor=self.input_state.cache_lens_tensor, state=self.input_state.state, ): self.graph.replay() @@ -541,9 +554,9 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, - prefix_lens: List[int], - prefix_lens_tensor: torch.Tensor, + seqlen: Seqlen, + cache_lens: List[int], + cache_lens_tensor: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, lm_head_indices: Optional[torch.Tensor] = None, @@ -587,9 +600,9 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, - prefix_lens=prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + seqlen=seqlen, + cache_lens=cache_lens, + cache_lens_tensor=cache_lens_tensor, max_s=max_s, adapter_data=adapter_data, lm_head_indices=lm_head_indices, From 652eb8f697363a5467117eec655a71a4c9fb3ea9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 13:47:13 -0700 Subject: [PATCH 05/34] WIP: generate_tokens --- server/lorax_server/models/flash_causal_lm.py | 545 ++++++++++++------ 1 file changed, 374 insertions(+), 171 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2cf976968..c481d106d 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -33,7 +33,7 @@ from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_speculative_tokens, get_supports_chunking, warmup_mode +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_max_prefill_tokens, get_speculative_tokens, get_supports_chunking, warmup_mode from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import Weights @@ -1520,8 +1520,61 @@ def generate_token( speculative_logits = ( speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits ) + if len(batch) > 1 and prefill_logprobs: + # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs + # When batch == 1, we will just use the batch.input_ids values directly + prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) else: + prefill_logprobs = None next_token_logits = out + next_adapter_indices = batch.adapter_meta.adapter_indices + + finished_prefilling = True + next_chunk_lengths = [] + current_prefilling_mask = batch.prefilling_mask + if prefill: + if get_supports_chunking(): + next_prefilling_mask = [] + # Budget in tokens for the next batch + # We remove (len(batch) - 1) to always have enough space for at least a single decode + # for the remaining requests -1 because the first request does not need to be removed from the budget + # (ex: you have one request in the batch, you want it to take the full budget not budget -1) + batch_budget = get_max_prefill_tokens() - (len(batch) - 1) + # We reverse to prioritize older requests + # zip() is not reversible so reverse the underlying lists instead + for cache_length, 
input_length, prompt_length in zip( + reversed(batch.cache_lengths), + reversed(batch.input_lengths), + reversed(batch.prompt_lengths), + ): + remaining_prefill_tokens = max( + prompt_length - cache_length - input_length, 0 + ) + if remaining_prefill_tokens > 0: + next_chunk_length = max( + min(remaining_prefill_tokens, batch_budget), 1 + ) + batch_budget -= next_chunk_length + finished_prefilling = False + next_prefilling_mask.append(True) + else: + # FIXME: use true number of accepted tokens instead of 1 + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_prefilling_mask.append(False) + next_chunk_lengths.append(next_chunk_length) + + # Reverse back the obtained values² + next_chunk_lengths.reverse() + next_prefilling_mask.reverse() + else: + # The model does not support chunking + # We know we only do a single prefill + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) + + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask speculative_tokens = get_speculative_tokens() ( @@ -1530,7 +1583,7 @@ def generate_token( accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], + batch.all_input_ids_tensor[:, : batch.max_current_length], next_token_logits, speculative_tokens, batch.speculative_ids, @@ -1542,37 +1595,25 @@ def generate_token( torch.log_softmax(next_token_logits, -1), dim=-1, stable=True, descending=True ) - if prefill: - if len(batch) > 1 and prefill_logprobs: - # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs - # When batch == 1, we will just use the batch.input_ids values directly - prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - + # Since we are done prefilling, all the tensors that were concatenating values for all the requests + # instantly become of shape [BATCH_SIZE] + if prefill and finished_prefilling: next_position_ids = batch.position_ids.new_empty(len(batch)) batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - # We do not need cu_seqlen_prefill anymore - batch.cu_seqlen_prefill = None - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch)) - else: - prefill_logprobs = None + elif not prefill: next_position_ids = batch.position_ids - next_adapter_indices = batch.adapter_meta.adapter_indices - - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - - # During warmup, do not allow early stopping - stopped = not is_warmup # Zipped iterator iterator = zip( + batch.requests, + batch.prompt_lengths, + batch.cache_lengths, batch.input_lengths, batch.all_input_ids, accepted_ids, + current_prefilling_mask, + batch.prefilling_mask, ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second @@ -1580,21 +1621,23 @@ def generate_token( # It is faster if we delay this sync for the maximum amount of time # For each member of the batch - idx = 0 + index = 0 + # Cumulative length + cumulative_length = 0 for i, ( + request, + prompt_length, + cache_length, input_length, all_input_ids, - num_accepted_ids, + n_accepted_ids, + request_was_prefilling, + request_is_prefilling, ) in enumerate(iterator): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - if prefill: + if prefill and finished_prefilling: # Indexing metadata - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = 
batch.prefill_cu_outlens[i + 1] - out_length = out_end_index - out_start_index + _start_index = cumulative_length + end_index = cumulative_length + input_length # Initialize position_ids # In decode, we do not need this as we can just increment position ids @@ -1605,33 +1648,55 @@ def generate_token( next_adapter_indices[i] = batch.adapter_meta.adapter_indices[end_index - 1] # Used to gather prefill logprobs - # Copy batch.input_ids to prefill_token_indices - if prefill_logprobs: + # Copy batch.all_input_ids_tensor to prefill_token_indices + if request.prefill_logprobs and request_was_prefilling: + # Indexing metadata + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + ids = batch.all_input_ids_tensor[ + i, cache_length + 1 : cache_length + input_length + 1 + ] if len(batch) > 1: - prefill_tokens_indices[out_start_index : out_end_index - 1] = batch.input_ids[ - start_index + 1 : start_index + out_length - ] + prefill_tokens_indices[out_start_index:out_end_index] = ids else: # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[start_index + 1 : start_index + out_length] + prefill_tokens_indices = ids - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + if not request_is_prefilling: + # Only save tokens if we are done prefilling for this request + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( + next_input_ids[index + j] + ) - for j in range(num_accepted_ids): - batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[idx] - idx += 1 + batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + index += n_accepted_ids cumulative_length += input_length - # Set values in batch - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] - batch.position_ids = next_position_ids + accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices - batch.speculative_ids = speculative_ids - batch.input_lengths_tensor += accepted_ids - batch.slot_indices += accepted_ids + # Update values + # These values can be updated without a GPU -> CPU sync + if not prefill or (prefill and finished_prefilling): + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices - if prefill: + if prefill and prefill_logprobs: + # Get prefill logprobs + prefill_logprobs_tensor = torch.log_softmax(out, -1) + prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) + # GPU <-> CPU sync + prefill_logprobs = prefill_logprobs.view(-1).tolist() + + # Does a GPU <-> CPU sync internally + if prefill and finished_prefilling: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( @@ -1640,170 +1705,305 @@ def generate_token( device=batch.adapter_meta.adapter_segments.device, ) - if prefill and prefill_logprobs: - # Get prefill logprobs - prefill_logprobs_tensor = torch.log_softmax(out, -1) - prefill_logprobs = 
torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) - # GPU <-> CPU sync - prefill_logprobs = prefill_logprobs.view(-1).tolist() - # GPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() + accepted_ids = accepted_ids.tolist() if return_alternatives: alternative_token_logprobs = alternative_token_logprobs.tolist() alternative_token_ids = alternative_token_ids.tolist() + + # Update values if we need to continue prefilling + # This represents the `else` case of the `Update values` if above + # but since this require the `next_token_ids` to be on CPU, it is better to do it here + if prefill and not finished_prefilling: + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert batch.speculative_ids is None + + all_postfix_ids = [] + for i, ( + request_prefilling, + next_token_id, + all_input_ids, + cache_length, + input_length, + next_chunk_length, + ) in enumerate( + zip( + batch.prefilling_mask, + next_token_ids, + batch.all_input_ids, + batch.cache_lengths, + batch.input_lengths, + next_chunk_lengths, + ) + ): + if request_prefilling: + next_cache_length = cache_length + input_length + # Get new prompt IDs to prefill + postfix_ids = all_input_ids[ + next_cache_length : next_cache_length + next_chunk_length + ] + else: + # This request is done prefilling, the new id is the one selected the sampling method + postfix_ids = [next_token_id] + + all_postfix_ids.append(postfix_ids) + + batch.input_ids = all_postfix_ids + + # Results + generations: List[Generation] = [] + stopped = not is_warmup # Zipped iterator iterator = zip( batch.requests, + batch.prompt_lengths, + batch.cache_lengths, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, - batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, + current_prefilling_mask, + batch.prefilling_mask, accepted_ids, ) + # Reset max_input_length + batch.max_input_length = 0 # For each member of the batch - idx = 0 + index = 0 for i, ( request, + prompt_length, + cache_length, input_length, prefix_offset, read_offset, stopping_criteria, all_input_ids, - prefix_ids, do_sample, seed, - num_accepted_ids, + request_was_prefilling, + request_is_prefilling, + n_accepted_ids, ) in enumerate(iterator): all_alternative_tokens = [] if request.parameters.return_k_alternatives > 0 else None - next_token_texts = [] - left = 0 - current_stopped = False - for j in range(num_accepted_ids): - token_idx = idx + j - - # Generated token - next_token_id = next_token_ids[token_idx] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - if request.parameters.return_k_alternatives > 0: - # Limit the number of alternatives to the vocabulary size - num_alternatives = min( - request.parameters.return_k_alternatives, - len(alternative_token_ids[token_idx]), - ) + # TODO(travis): return_k_alternatives + # if request.parameters.return_k_alternatives > 0: + # # Limit the number of alternatives to the vocabulary size + # num_alternatives = min( + # request.parameters.return_k_alternatives, + # len(alternative_token_ids[token_idx]), + # ) + + # # Select top-k logprobs + # request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] + # request_alternative_token_logprobs = 
alternative_token_logprobs[token_idx][:num_alternatives] + + # # Decode tokens + # request_alternative_token_texts = [] + # for alternative_token_id in request_alternative_token_ids: + # all_input_ids.append(alternative_token_id) + # alternative_token_text, _, _ = self.decode_token( + # all_input_ids, + # prefix_offset, + # read_offset, + # ) + # request_alternative_token_texts.append(alternative_token_text) + # all_input_ids.pop() + # alternative_tokens = AlternativeTokens( + # request_alternative_token_ids, + # request_alternative_token_logprobs, + # request_alternative_token_texts, + # ) + # all_alternative_tokens.append(alternative_tokens) + + # Compute logprobs first as, even though we might skip the token, + # it can still be required to compute the logprobs + # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need + # this state to be stable + if request.id % self.world_size == self.rank: + # Prefill + if request_was_prefilling and request.prefill_logprobs: + out_start_index = batch.prefill_cu_outlens[i] + out_end_index = batch.prefill_cu_outlens[i + 1] + if not request_is_prefilling: + # The request is dones prefilling, meaning that we started generating new tokens + # The last logprob is a logprob for a generated token that was not part of the prompt + # We need to remove it + out_end_index -= 1 + + request_prefill_logprobs = prefill_logprobs[ + out_start_index:out_end_index + ] + # Logprobs generated by the model are for the next token + # So we need to translate the id tensor by 1 + prefill_token_ids = all_input_ids[ + cache_length + 1 : cache_length + input_length + 1 + ] - # Select top-k logprobs - request_alternative_token_ids = alternative_token_ids[token_idx][:num_alternatives] - request_alternative_token_logprobs = alternative_token_logprobs[token_idx][:num_alternatives] + past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] - # Decode tokens - request_alternative_token_texts = [] - for alternative_token_id in request_alternative_token_ids: - all_input_ids.append(alternative_token_id) - alternative_token_text, _, _ = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, + if past_prefill_logprob_tokens is None: + # add nan for cached prompt tokens/first token + request_prefill_logprobs = [float("nan")] * ( + cache_length + 1 + ) + request_prefill_logprobs + prefill_token_ids = ( + all_input_ids[: cache_length + 1] + prefill_token_ids ) - request_alternative_token_texts.append(alternative_token_text) - all_input_ids.pop() - alternative_tokens = AlternativeTokens( - request_alternative_token_ids, - request_alternative_token_logprobs, - request_alternative_token_texts, + + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, ) - all_alternative_tokens.append(alternative_tokens) - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) + prefill_logprob_tokens = NextTokens( + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + is_special=[], + ) + if past_prefill_logprob_tokens is not None: + prefill_logprob_tokens = ( + past_prefill_logprob_tokens + prefill_logprob_tokens + ) - if stop: - left = num_accepted_ids - j - 1 - current_stopped = True - break + batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: - current_stopped = False - stopped = stopped and current_stopped - - accepted_token_ids = next_token_ids[idx : idx + num_accepted_ids - left] - accepted_token_logprobs = 
next_token_logprobs[idx : idx + num_accepted_ids - left] - idx += num_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :]) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + batch.prefill_logprob_tokens[i] = None + + # If it is, the tokens we decoded should be ignored + if request_is_prefilling: + # Make sure that we do not stop as even though this request did not create a token, it is still + # processing + stopped = False + new_input_length = next_chunk_lengths[i] + else: + new_input_length = n_accepted_ids + # Append next token to all tokens + next_token_texts = [] + left = 0 + + if n_accepted_ids > 1: + logger.debug(f"speculated ids {n_accepted_ids - 1}") + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, ) - else: - generated_text = None - - # Prefill - if prefill and request.prefill_logprobs: - out_start_index = batch.prefill_cu_outlens[i] - out_end_index = batch.prefill_cu_outlens[i + 1] + next_token_texts.append(next_token_text) - # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = ([float("nan")] * (len(prefix_ids) + 1)) + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] - prefill_token_ids = all_input_ids[:-1] - prefill_texts = self.tokenizer.batch_decode( - prefix_ids + prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + stop, reason = stopping_criteria( + next_token_id, + next_token_text, ) - prefill_tokens = PrefillTokens(prefill_token_ids, request_prefill_logprobs, prefill_texts) - else: - prefill_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - len(all_input_ids[:-1]) if prefill else 0, - NextTokens( - accepted_token_ids, - accepted_token_logprobs, - next_token_texts, - [tid in self.all_special_ids for tid in accepted_token_ids], - all_alternative_tokens, - ), - generated_text, - ) - generations.append(generation) + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # TODO(travis): top tokens + # if top_n_tokens > 0: + # all_top_tokens = [] + # for top_token_ids, top_token_logprobs in zip( + # top_token_ids, top_token_logprobs + # ): + # toptoken_texts = self.tokenizer.batch_decode( + # 
top_token_ids, + # clean_up_tokenization_spaces=False, + # skip_special_tokens=False, + # ) + # special_toptokens = [ + # token_id in self.all_special_ids + # for token_id in top_token_ids + # ] + # top_tokens = Tokens( + # top_token_ids, + # top_token_logprobs, + # toptoken_texts, + # special_toptokens, + # ) + # all_top_tokens.append(top_tokens) + # top_tokens = all_top_tokens + # else: + # top_tokens = None + + generation = Generation( + request.id, + batch.prefill_logprob_tokens[i], + NextTokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + ) - # advance the FSM for each accepted token (as we may have more than one from speculative decoding) - for next_token_id in accepted_token_ids: - batch.next_token_chooser.next_state(i, next_token_id) + generations.append(generation) + + # advance the FSM for each accepted token (as we may have more than one from speculative decoding) + for next_token_id in _next_token_ids: + batch.next_token_chooser.next_state(i, next_token_id) # Update values - batch.input_lengths[i] = input_length + num_accepted_ids.item() - if batch.input_lengths[i] > batch.max_seqlen: - batch.max_seqlen = batch.input_lengths[i] + index += n_accepted_ids + current_cache_length = cache_length + input_length + batch.cache_lengths[i] = current_cache_length + current_input_length = new_input_length + batch.max_input_length = max(batch.max_input_length, current_input_length) + batch.input_lengths[i] = current_input_length + current_length = current_cache_length + current_input_length + batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1812,9 +2012,12 @@ def generate_token( # No need to return a batch if we know that all requests stopped return generations, None - batch.prefill_cu_outlens = None - batch.prefill_head_indices = None - batch.prefill_next_token_indices = None - batch.max_seqlen = batch.max_seqlen + 1 + if prefill and finished_prefilling: + # We do not need prefill tensors anymore + batch.cu_seqlen_prefill = None + batch.prefill_cache_indices = None + batch.prefill_cu_outlens = None + batch.prefill_head_indices = None + batch.prefill_next_token_indices = None return generations, batch From 996fd2d08ea3bbd0f7932c6ea4a30f1d93f5476e Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 13:52:08 -0700 Subject: [PATCH 06/34] Fix prepare_for_prefill --- server/lorax_server/models/flash_causal_lm.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index c481d106d..d342b1565 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -784,12 +784,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) def prepare_for_prefill(self): + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything assert self.speculative_ids is None - sliding_window = get_sliding_windows() position_ids = [] cu_seqlen_prefill = [0] slot_indices = [] @@ -853,9 +855,9 @@ def prepare_for_prefill(self): ) # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: + if 
SLIDING_WINDOW is not None: request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), + cumulative_length + max(0, input_length - SLIDING_WINDOW), cumulative_length + input_length, dtype=torch.int64, ) @@ -893,13 +895,11 @@ def prepare_for_prefill(self): slots.extend(request_slots) slot_indices.append(request_slot_indices) - if sliding_window is not None: + if SLIDING_WINDOW is not None: prefill_cache_indices.append(request_prefill_cache_indices) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index)) - adapter_set.add(adapter_index) + adapter_indices_list.append(torch.full((next_chunk_length,), r.adapter_index)) + adapter_set.add(r.adapter_index) # Update cumulative_length += next_chunk_length @@ -917,12 +917,12 @@ def prepare_for_prefill(self): if len(self) > 1: position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) - if sliding_window is not None: + if SLIDING_WINDOW is not None: prefill_cache_indices = torch.cat(prefill_cache_indices) else: position_ids = position_ids[0] slot_indices = slot_indices[0] - if sliding_window is not None: + if SLIDING_WINDOW is not None: prefill_cache_indices = prefill_cache_indices[0] self.prefill_cu_outlens = prefill_cu_outlens @@ -933,7 +933,7 @@ def prepare_for_prefill(self): self.position_ids = position_ids.to(device) self.slot_indices = slot_indices.to(device) self.prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None + prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None ) self.input_lengths_tensor = torch.tensor( self.input_lengths, dtype=torch.int32, device=device From 582be6c755f7c1dc1da8f19bf781561633bc7e77 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 13:58:20 -0700 Subject: [PATCH 07/34] Plumbing --- server/lorax_server/models/model.py | 23 ++++++++++++++++++++- server/lorax_server/models/vlm_causal_lm.py | 2 ++ server/lorax_server/utils/state.py | 2 ++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 2ab15c61a..f51087f35 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -17,7 +17,7 @@ load_and_merge_adapters, ) from lorax_server.utils.sources import HUB -from lorax_server.utils.state import BLOCK_SIZE, get_speculative_tokens +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFILL_CHUNKING, get_speculative_tokens, set_supports_chunking from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import shard_on_dim @@ -41,6 +41,7 @@ def __init__( dynamic_adapter_loading_enabled: bool = True, trust_remote_code: bool = False, processor=None, + supports_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() @@ -73,6 +74,26 @@ def __init__( self.trust_remote_code = trust_remote_code + speculation_tokens = get_speculative_tokens() + + support_chunking = support_chunking and PREFILL_CHUNKING + if support_chunking: + if speculation_tokens != 0: + logger.warning( + "Prefill chunking does not support speculation yet. 
" + "Prefill chunking will be turned off", + ) + support_chunking = False + if not FLASH_INFER: + logger.warning( + "Prefill chunking is only supported with `flashinfer` backend.", + ) + support_chunking = False + logger.info(f"Using experimental prefill chunking = {support_chunking}") + + self.supports_chunking = supports_chunking + set_supports_chunking(supports_chunking) + self.has_position_ids = inspect.signature(model.forward).parameters.get("position_ids", None) is not None if dynamic_adapter_loading_enabled and adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID: diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 8b03e1b8d..133bad457 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -277,6 +277,8 @@ def __init__( adapter_source=adapter_source, processor=processor, trust_remote_code=trust_remote_code, + # FIXME: VLM do not work with context chunking yet + support_chunking=False, **kwargs, ) diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 809c756a2..33ec1c0fc 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -17,6 +17,8 @@ if FLASH_INFER: logger.info("Using flashinfer") +PREFILL_CHUNKING = bool(os.environ.get("PREFILL_CHUNKING", "")) +logger.info(f"Prefill chunking = {PREFILL_CHUNKING}") SUPPORTS_CHUNKING: Optional[bool] = None MAX_PREFILL_TOKENS: Optional[int] = None From 445b2c248cde0ec5c88100acf75e8cd2f3a13a7a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 13:58:55 -0700 Subject: [PATCH 08/34] InfoResponse --- server/lorax_server/models/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index f51087f35..f7722c0fe 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -124,6 +124,7 @@ def info(self) -> InfoResponse: supports_generation=self.supports_text_generation, supports_embeddings=self.supports_embeddings, supports_classification=self.supports_classification, + supports_chunking=self.supports_chunking, ) @property From b87fe8784516caad1c08479c7effea6fa0cf2f07 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 14:02:08 -0700 Subject: [PATCH 09/34] Seqlen for llama --- .../custom_modeling/flash_llama_modeling.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index bc21484a2..e5ccd2617 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -20,6 +20,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + # Flash attention imports import dropout_layer_norm import torch @@ -300,7 +302,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -344,7 +346,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -431,7 +433,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, cross_attention_states, @@ -447,7 +449,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -500,7 +502,7 @@ def forward( kv_cache: 
List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor], @@ -523,7 +525,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, cross_attention_states, @@ -568,7 +570,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -583,7 +585,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, From b95d6edc17e9d93775e7e9c05cafdd43b459823d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 14:08:56 -0700 Subject: [PATCH 10/34] Rename --- proto/generate.proto | 2 +- router/src/infer.rs | 13 +++++++------ router/src/server.rs | 1 + server/lorax_server/models/model.py | 12 ++++++------ 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 9ba083efc..cc287bfe0 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -57,7 +57,7 @@ message InfoResponse { bool supports_generation = 8; bool supports_embeddings = 9; bool supports_classification = 10; - bool supports_chunking = 11; + bool prefill_chunking = 11; } /// Empty request diff --git a/router/src/infer.rs b/router/src/infer.rs index 11eeb374f..0b40e3957 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -181,6 +181,7 @@ impl Infer { speculate: u32, preloaded_adapters: Vec, prefix_caching: bool, + prefill_chunking: bool, is_causal_lm: bool, ) -> Self { let adapter_event = Arc::new(AdapterEvent { @@ -250,6 +251,7 @@ impl Infer { generation_health, adapter_scheduler.clone(), eager_prefill, + prefill_chunking, )); // Inference limit with a semaphore @@ -898,9 +900,8 @@ async fn batching_task( generation_health: Arc, adapter_scheduler: AdapterScheduler, eager_prefill: bool, + prefill_chunking: bool, ) { - let support_chunking = true; - // Infinite loop loop { // Fire if a new request comes in or an adapter becomes ready @@ -939,7 +940,7 @@ async fn batching_task( adapter_scheduler.remove_errored_adapters().await; let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let (min_size, max_size, prefill_token_budget) = if support_chunking { + let (min_size, max_size, prefill_token_budget) = if prefill_chunking { // Since the next batch will be concatenated with the current batch, // the current batch tokens must be subtracted to the prefill budget let prefill_token_budget = @@ -999,14 +1000,14 @@ async fn batching_task( if min_size.is_some() { metrics::increment_counter!("lorax_batch_concat", "reason" => "backpressure"); } else { - if support_chunking { + if prefill_chunking { metrics::increment_counter!("lorax_batch_concat", "reason" => "chunking") } else { metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded") }; } - let cached_batch = if support_chunking { + let cached_batch = if prefill_chunking { // Concat current batch to the new one batches.pop() } else { @@ -1047,7 +1048,7 @@ async fn batching_task( // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { batches.push(new_cached_batch); - } else if support_chunking { + } else if prefill_chunking { // New cached batch is empty, no 
work left break; } diff --git a/router/src/server.rs b/router/src/server.rs index 758348fdf..fb07839ca 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1401,6 +1401,7 @@ pub async fn run( shard_info.speculate, shard_info.preloaded_adapters, prefix_caching, + shard_info.prefill_chunking, is_causal_lm, ); diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index f7722c0fe..babee5d80 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -76,20 +76,20 @@ def __init__( speculation_tokens = get_speculative_tokens() - support_chunking = support_chunking and PREFILL_CHUNKING - if support_chunking: + supports_chunking = supports_chunking and PREFILL_CHUNKING + if supports_chunking: if speculation_tokens != 0: logger.warning( "Prefill chunking does not support speculation yet. " "Prefill chunking will be turned off", ) - support_chunking = False + supports_chunking = False if not FLASH_INFER: logger.warning( "Prefill chunking is only supported with `flashinfer` backend.", ) - support_chunking = False - logger.info(f"Using experimental prefill chunking = {support_chunking}") + supports_chunking = False + logger.info(f"Using experimental prefill chunking = {supports_chunking}") self.supports_chunking = supports_chunking set_supports_chunking(supports_chunking) @@ -124,7 +124,7 @@ def info(self) -> InfoResponse: supports_generation=self.supports_text_generation, supports_embeddings=self.supports_embeddings, supports_classification=self.supports_classification, - supports_chunking=self.supports_chunking, + prefill_chunking=self.supports_chunking, ) @property From 5720b3fbe9390cfd637951fd0675452266b111e7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 14:18:08 -0700 Subject: [PATCH 11/34] Responses are correct --- server/lorax_server/models/flash_causal_lm.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index d342b1565..7596b4bab 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -318,6 +318,7 @@ def from_pb( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, + input_ids=all_postfix_ids, block_tables=block_tables, block_tables_tensor=block_tables_tensor, cache_lengths=cache_lengths, @@ -1218,7 +1219,7 @@ def adapter_memory_size(self) -> int: def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model: bool = False): # The warmup batch is the biggest batch we could ever receive - max_total_tokens = batch.max_seqlen + max_new_tokens + get_speculative_tokens() + max_total_tokens = batch.max_input_length + max_new_tokens + get_speculative_tokens() torch.cuda.empty_cache() try: @@ -1236,9 +1237,9 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model logger.info("Warming up to max_new_tokens: {}", max_new_tokens) with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar: for _ in range(max_new_tokens): - cur_seqlen = batch.max_seqlen + cur_seqlen = batch.max_current_length _, batch = self.generate_token(batch, is_warmup=True) - new_seqlen = batch.max_seqlen + new_seqlen = batch.max_current_length pbar.update(new_seqlen - cur_seqlen) if new_seqlen >= max_total_tokens - get_speculative_tokens(): break @@ -1444,7 +1445,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> 
input_lengths=batch.input_lengths, input_lengths_tensor=input_lengths, cache_lens=batch.cache_lengths, - cache_lengths_tensor=cache_lengths_tensor, + cache_lens_tensor=cache_lengths_tensor, ): out = model.forward( input_ids=input_ids, @@ -1871,7 +1872,8 @@ def generate_token( prefill_token_ids, request_prefill_logprobs, prefill_texts, - is_special=[], + [], + all_alternative_tokens, ) if past_prefill_logprob_tokens is not None: prefill_logprob_tokens = ( @@ -1979,11 +1981,13 @@ def generate_token( generation = Generation( request.id, batch.prefill_logprob_tokens[i], + len(all_input_ids[:-1]) if prefill else 0, NextTokens( _next_token_ids, _next_token_logprobs, next_token_texts, [nid in self.all_special_ids for nid in _next_token_ids], + all_alternative_tokens, ), generated_text, ) From 9d75bb063b791aa5a81ce75ea6dac7f14c8bd60c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 14:24:34 -0700 Subject: [PATCH 12/34] Fixed compile --- server/lorax_server/models/flash_causal_lm.py | 2 +- server/lorax_server/utils/graph.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 7596b4bab..34789ec38 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1471,7 +1471,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> slots=slots, seqlen=seqlen, cache_lens=batch.cache_lengths, - cache_lengths_tensor=cache_lengths_tensor, + cache_lens_tensor=cache_lengths_tensor, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index c16515d2e..1605c3250 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -148,7 +148,8 @@ def get_max_graph_state( max_q=1, max_k=max_total_tokens, ), - cache_lengths_tensor=cache_lengths_tensor, + cache_lens=cache_lengths, + cache_lens_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), @@ -240,7 +241,7 @@ def trace( } block_tables = max_input_state.block_tables[:batch_size] - input_lengths = max_input_state.input_lengths[:batch_size] + input_lengths = max_input_state.seqlen.input_lengths[:batch_size] cache_lengths = max_input_state.cache_lens[:batch_size] cache_lengths_tensor = max_input_state.cache_lens_tensor[:batch_size] state = None @@ -401,8 +402,8 @@ def forward( with self.forward_context( block_tables=self.input_state.block_tables, cu_seqlen_prefill=None, - input_lengths=input_lengths, - input_lengths_tensor=self.input_state.input_lengths, + input_lengths=seqlen.input_lengths, + input_lengths_tensor=self.input_state.seqlen.input_lengths, cache_lens=self.input_state.cache_lens, cache_lens_tensor=self.input_state.cache_lens_tensor, state=self.input_state.state, @@ -452,7 +453,7 @@ def can_use_graph( max_rank = max(ranks) if len(ranks) > 0 else 0 batch_size = batch.input_ids.shape[0] - max_s = batch.max_seqlen + max_s = batch.max_current_length # TODO(travis): allow using CUDA graphs with multi-rank batches return ( From 8d413db7710bc02eb397c06fb40de22b11d3cf22 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 15:07:29 -0700 Subject: [PATCH 13/34] Fix flashinfer --- server/lorax_server/models/flash_causal_lm.py | 13 ++++---- 
server/lorax_server/models/mllama.py | 2 +- server/lorax_server/models/vlm_causal_lm.py | 2 +- server/lorax_server/utils/flash_attn.py | 2 +- .../utils/flashinfer_attention.py | 33 ++++++++++++++----- server/lorax_server/utils/graph.py | 4 +-- 6 files changed, 35 insertions(+), 21 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 34789ec38..66fb9e185 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1348,29 +1348,28 @@ def _forward_context( if cu_seqlen_prefill is not None: return use_prefill_with_paged_kv_state( state=(state if state is not None else self.prefill_with_paged_kv_state), - # block_tables=block_tables_to_ragged( - # block_tables=block_tables, - # input_lengths=input_lengths, - # prefix_lens=prefix_lens, - # ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor, + input_lengths=input_lengths_tensor + cache_lens_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, ) else: assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor, + input_lengths=input_lengths_tensor + cache_lens_tensor, block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, ) def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py index c7861a015..342570094 100644 --- a/server/lorax_server/models/mllama.py +++ b/server/lorax_server/models/mllama.py @@ -226,7 +226,7 @@ def forward( block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.prefix_lens, ) with self._forward_context( block_tables=block_tables, diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 133bad457..ce2d9b457 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -355,7 +355,7 @@ def forward( block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + cache_lengths=batch.prefix_lens, ) with self._forward_context( block_tables=block_tables, diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index 404cec60c..560ddadc4 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -128,7 +128,6 @@ def attention( causal=True, softcap=0.0, ): - assert window_size_left == -1, "Windowing is not supported with flash infer when using kv cache" from lorax_server.utils.flashinfer_attention import prefill_state, prefill_with_paged_kv_state if key_cache is None or value_cache is None: @@ -149,6 +148,7 @@ def attention( paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, + window_left=window_size_left, ) elif HAS_FLASH_ATTN_V2_CUDA: diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index 1ffea21a7..c8f8d4c00 100644 --- 
a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -44,7 +44,8 @@ def use_prefill_with_paged_kv_state( num_kv_heads: int, head_size: int, page_size: int, - query_dtype: str = "float16", + dtype: torch.dtype, + window_left: int, ): """ Context manager to set the active flashinfer prefill state to the given @@ -52,7 +53,9 @@ def use_prefill_with_paged_kv_state( `attention` function while the context manager is active. """ - indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) # Round up to page size and then calculate the cumulative sum to get # the indices into the block table. torch.add(input_lengths, page_size - 1, out=indptr[1:]) @@ -61,9 +64,13 @@ def use_prefill_with_paged_kv_state( # Get the lengths of the last page in a block. if page_size == 1: - last_page_len = torch.ones(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) else: - last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) torch.sub(input_lengths, 1, out=last_page_len) last_page_len.remainder_(page_size) last_page_len += 1 @@ -78,8 +85,9 @@ def use_prefill_with_paged_kv_state( num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, - q_data_type=query_dtype, + q_data_type=dtype, page_size=page_size, + # window_left=window_left, # TODO ) yield finally: @@ -182,14 +190,17 @@ def use_decode_state( num_kv_heads: int, head_size: int, page_size: int, - query_dtype: str = "float16", + dtype: torch.dtype, + window_left: int, ): """ Context manager to set the active flashinfer decoding state to the given `state` and parameters. This state will be used by all calls to the `paged_attention` function while the context manager is active. """ - indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) # Round up to page size and then calculate the cumulative sum to get # the indices into the block table. torch.add(input_lengths, page_size - 1, out=indptr[1:]) @@ -197,7 +208,9 @@ def use_decode_state( indptr[1:].cumsum_(-1) # Get the lengths of the last page in a block. 
- last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) torch.sub(input_lengths, 1, out=last_page_len) last_page_len.remainder_(page_size) last_page_len += 1 @@ -213,7 +226,9 @@ def use_decode_state( num_kv_heads=num_kv_heads, head_dim=head_size, page_size=page_size, - q_data_type=query_dtype, + data_type=dtype, + q_data_type=dtype, + # window_left=window_left, TODO ) yield finally: diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 1605c3250..da10d7cce 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -254,7 +254,7 @@ def trace( block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=input_lengths.tolist(), - cache_lens=cache_lengths, + cache_lengths=cache_lengths, ) block_tables_ptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) @@ -370,7 +370,7 @@ def forward( block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=seqlen.input_lengths, - cache_lens=seqlen.cache_lengths, + cache_lengths=seqlen.cache_lengths, ) self.input_state.block_tables[: block_tables.shape[0]] = block_tables else: From 053c3e5d8ebee01ba227692ea7717baa88da864f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 15:36:58 -0700 Subject: [PATCH 14/34] Fix prefill --- server/lorax_server/server.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index ce8171dec..c1da77fed 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -24,7 +24,7 @@ is_base_model, ) from lorax_server.utils.sgmv import has_sgmv -from lorax_server.utils.state import set_speculative_tokens +from lorax_server.utils.state import set_max_prefill_tokens, set_speculative_tokens class LoraxService(generate_pb2_grpc.LoraxServiceServicer): @@ -74,6 +74,8 @@ async def FilterBatch(self, request, context): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request: generate_pb2.WarmupRequest, context): + set_max_prefill_tokens(request.max_prefill_tokens) + batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, @@ -98,6 +100,15 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): self.model.device, ) + if self.model.support_chunking: + if request.HasField("cached_batch"): + cached_batch = self.cache.pop(request.cached_batch.id) + if cached_batch is None: + raise ValueError( + f"Batch ID {request.cached_batch.id} not found in cache." 
+ ) + batch = self.model.batch_type.concatenate([cached_batch, batch]) + generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) From 2834f5ccaed8c9c5d8437006aed2d279252eb017 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 15:47:18 -0700 Subject: [PATCH 15/34] Fix chunking --- proto/generate.proto | 6 +++--- router/client/src/client.rs | 2 ++ server/lorax_server/server.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index cc287bfe0..6e3dae122 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -363,9 +363,9 @@ message ClassifyResponse { message WarmupRequest { /// Batch to warmup on Batch batch = 1; - - /// Maximum number of new tokens to warmup - uint32 max_new_tokens = 2; + uint32 max_input_length = 2; + uint32 max_prefill_tokens = 3; + uint32 max_new_tokens = 4; } /// Empty response diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 7007a26aa..0751347da 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -181,6 +181,8 @@ impl Client { let max_new_tokens = max_total_tokens - max_input_length; let request = tonic::Request::new(WarmupRequest { batch: Some(batch), + max_input_length, + max_prefill_tokens, max_new_tokens, }) .inject_context(); diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index c1da77fed..e1d2139eb 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -100,7 +100,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): self.model.device, ) - if self.model.support_chunking: + if self.model.supports_chunking: if request.HasField("cached_batch"): cached_batch = self.cache.pop(request.cached_batch.id) if cached_batch is None: From 9a45299057a54e4f358394e77aaf4648340241c6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 15:49:35 -0700 Subject: [PATCH 16/34] Docker build --- .github/workflows/build.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 13b9e96ca..07a7d265a 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,7 @@ on: push: branches: - 'main' + - 'chunked-prefill' tags: - 'v*' @@ -69,10 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix=,suffix=,format=short - type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} + type=raw,value=chunked-prefill,enable=${{ github.ref == 'refs/heads/chunked-prefill' }} - name: Create a hash from tags env: From 147e8f9687a97853a117db2e100da43bf74d4f2b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 15:51:17 -0700 Subject: [PATCH 17/34] Fix id --- proto/generate.proto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 6e3dae122..f03fb7d2b 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -159,10 +159,10 @@ message Request { repeated uint32 slots = 10; /// Tokens that can be retrieved from the KV cache. 
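
The chunked prefill path through Prefill() above, as a minimal sketch (the wrapper function and its arguments are illustrative; the attribute and method names are the ones used in the patch): when the router sends a cached_batch id along with a new chunk, the previously prefilled chunk is popped from the cache and concatenated with the new one before a single generate_token step.

    def prefill_chunk(model, cache, request, batch):
        # `batch` is the freshly deserialized chunk; `cache` maps batch ids to cached batches
        if model.supports_chunking and request.HasField("cached_batch"):
            cached = cache.pop(request.cached_batch.id)            # earlier chunks, already prefilled
            if cached is None:
                raise ValueError(f"Batch ID {request.cached_batch.id} not found in cache.")
            batch = model.batch_type.concatenate([cached, batch])  # resume prefill on top of the cached KV
        generations, next_batch = model.generate_token(batch)
        cache.set(next_batch)                                      # its id comes back as cached_batch for the next chunk
        return generations, next_batch
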
/// This value is set for the first prefill and never reset - uint32 cache_len = 12; + uint32 cache_len = 11; /// Chunk of tokens that must be computed for the first prefill /// This value is set for the first prefill and never reset - optional uint32 chunk_len = 13; + optional uint32 chunk_len = 12; } message Batch { From 8a69aafad0741fa4e9bf5252b178a1322d56c9cd Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 16:43:19 -0700 Subject: [PATCH 18/34] Added missing file --- server/lorax_server/utils/attention/common.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 server/lorax_server/utils/attention/common.py diff --git a/server/lorax_server/utils/attention/common.py b/server/lorax_server/utils/attention/common.py new file mode 100644 index 000000000..61b7cc502 --- /dev/null +++ b/server/lorax_server/utils/attention/common.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass +import torch +from typing import Optional + + +@dataclass +class Seqlen: + input_lengths: torch.Tensor + cache_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + max_q: int + max_k: int + + def __init__( + self, + input_lengths, + cache_lengths, + cu_seqlen_q=None, + max_q=None, + max_k=None, + ): + self.input_lengths = input_lengths + self.cache_lengths = cache_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + max_q = 1 + else: + assert max_q is not None + assert max_k is not None + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + total = self.input_lengths + self.cache_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + self.max_q = max_q + self.max_k = max_k + + def clamp(self, max): + self.input_lengths = torch.clamp(self.input_lengths, max=max) + return self + \ No newline at end of file From 5467678a48a65259d8ce8ec9059f743cf754909a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 16:43:37 -0700 Subject: [PATCH 19/34] Docker --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 07a7d265a..c1db5f731 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,7 +70,7 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=chunked-prefill,enable=${{ github.ref == 'refs/heads/chunked-prefill' }} + type=raw,value=chunked-prefill-2,enable=${{ github.ref == 'refs/heads/chunked-prefill' }} - name: Create a hash from tags env: From 80f47ce0ab42bfea177b0d36a8c4595a9852aae3 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 17 Oct 2024 21:18:59 -0700 Subject: [PATCH 20/34] Fix --- server/lorax_server/models/flash_causal_lm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 66fb9e185..40bcde9c2 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -519,6 +519,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids=position_ids, speculative_ids=speculative_ids, cu_seqlen_prefill=None, + 
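
A quick usage sketch for the Seqlen container added in server/lorax_server/utils/attention/common.py above (decode-shaped toy inputs): leaving cu_seqlen_q as None assumes one query token per sequence, while cu_seqlen_k accumulates input_lengths + cache_lengths so attention spans the whole cached prefix.

    import torch
    from lorax_server.utils.attention.common import Seqlen

    seqlen = Seqlen(
        input_lengths=torch.tensor([1, 1], dtype=torch.int32),   # one new token per sequence
        cache_lengths=torch.tensor([7, 12], dtype=torch.int32),  # tokens already in the KV cache
    )
    print(seqlen.cu_seqlen_q)  # [0, 1, 2]
    print(seqlen.cu_seqlen_k)  # [0, 8, 21]
    print(seqlen.max_q)        # 1
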
prefill_cache_indices=None, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, @@ -755,6 +756,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids=position_ids, speculative_ids=speculative_ids, cu_seqlen_prefill=None, + prefill_cache_indices=None, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, From a07b9b1cf195ea6f2613d574ae06acf91a79f7f3 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 09:07:23 -0700 Subject: [PATCH 21/34] Warnings --- router/src/batch.rs | 4 ++-- router/src/infer.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/router/src/batch.rs b/router/src/batch.rs index e908c899c..50b91550b 100644 --- a/router/src/batch.rs +++ b/router/src/batch.rs @@ -519,7 +519,7 @@ impl BatchEntries for EmbedBatchEntries { &mut self, client: &mut ShardedClient, batch: Batch, - cached_batch: Option, + _cached_batch: Option, span: Span, generation_health: &Arc, ) -> Option { @@ -644,7 +644,7 @@ impl BatchEntries for ClassifyBatchEntries { &mut self, client: &mut ShardedClient, batch: Batch, - cached_batch: Option, + _cached_batch: Option, span: Span, generation_health: &Arc, ) -> Option { diff --git a/router/src/infer.rs b/router/src/infer.rs index 0b40e3957..bef99b000 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -940,7 +940,7 @@ async fn batching_task( adapter_scheduler.remove_errored_adapters().await; let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let (min_size, max_size, prefill_token_budget) = if prefill_chunking { + let (min_size, _max_size, prefill_token_budget) = if prefill_chunking { // Since the next batch will be concatenated with the current batch, // the current batch tokens must be subtracted to the prefill budget let prefill_token_budget = From 982fd523fbdd06924891d91bf277af847f4da348 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 09:45:12 -0700 Subject: [PATCH 22/34] Fix flashinfer graph retrace --- server/lorax_server/utils/graph.py | 10 +++++++--- server/lorax_server/utils/state.py | 9 +++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index da10d7cce..8c1fb2170 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -84,6 +84,7 @@ class GraphState: block_tables: torch.Tensor slots: torch.Tensor seqlen: Seqlen + input_lengths: List[int] cache_lens: List[int] cache_lens_tensor: torch.Tensor adapter_data: AdapterBatchData @@ -148,6 +149,7 @@ def get_max_graph_state( max_q=1, max_k=max_total_tokens, ), + input_lengths=input_lengths.tolist(), cache_lens=cache_lengths, cache_lens_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( @@ -241,7 +243,8 @@ def trace( } block_tables = max_input_state.block_tables[:batch_size] - input_lengths = max_input_state.seqlen.input_lengths[:batch_size] + input_lengths = max_input_state.input_lengths[:batch_size] + input_lengths_tensor = max_input_state.seqlen.input_lengths[:batch_size] cache_lengths = max_input_state.cache_lens[:batch_size] cache_lengths_tensor = max_input_state.cache_lens_tensor[:batch_size] state = None @@ -253,7 +256,7 @@ def trace( block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=input_lengths.tolist(), + input_lengths=input_lengths, cache_lengths=cache_lengths, ) @@ -274,12 +277,13 @@ def trace( 
block_tables=block_tables, slots=max_input_state.slots[:batch_size], seqlen=Seqlen( - input_lengths=input_lengths, + input_lengths=input_lengths_tensor, cache_lengths=cache_lengths_tensor, cu_seqlen_q=None, max_q=1, max_k=max_total_tokens, ), + input_lengths=input_lengths, cache_lens=cache_lengths, cache_lens_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index 33ec1c0fc..f60f96b79 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -9,15 +9,16 @@ PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) -logger.info(f"Prefix caching = {PREFIX_CACHING}") - +PREFILL_CHUNKING = bool(os.environ.get("PREFILL_CHUNKING", "")) # Always use flashinfer when prefix caching is enabled FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING if FLASH_INFER: - logger.info("Using flashinfer") + logger.info("Backend = flashinfer") +else: + logger.info("Backend = fa2") -PREFILL_CHUNKING = bool(os.environ.get("PREFILL_CHUNKING", "")) +logger.info(f"Prefix caching = {PREFIX_CACHING}") logger.info(f"Prefill chunking = {PREFILL_CHUNKING}") SUPPORTS_CHUNKING: Optional[bool] = None From 2621897498be1a47d4f719d1522c422b9ccf3a36 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 09:53:14 -0700 Subject: [PATCH 23/34] In-place softmax --- server/lorax_server/models/flash_causal_lm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 40bcde9c2..96a7cb62a 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1691,8 +1691,9 @@ def generate_token( batch.adapter_meta.adapter_indices = next_adapter_indices if prefill and prefill_logprobs: - # Get prefill logprobs - prefill_logprobs_tensor = torch.log_softmax(out, -1) + # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) + torch.log_softmax(out, -1, out=out) + prefill_logprobs_tensor = out prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) # GPU <-> CPU sync prefill_logprobs = prefill_logprobs.view(-1).tolist() From e0b017772574a749933aae9472b66fff49d54548 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 10:08:01 -0700 Subject: [PATCH 24/34] Fix concatenate --- server/lorax_server/models/flash_causal_lm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 96a7cb62a..dafda244a 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -475,6 +475,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": max_blocks = max(max_blocks, len(request_block_table)) all_input_ids_tensor = self.all_input_ids_tensor[indices] + block_tables_tensor = self.block_tables_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None @@ -571,17 +572,17 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # `total_slots` is not used if any of the batches is prefilling total_slots += len(b.slots) if not b.prefilling else 0 num_blocks += b.num_blocks + speculative_length = ( + 
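
The in-place log_softmax above, isolated with made-up shapes (a sketch, not the patch code): the normalization is written back into the logits tensor, so no second [prefill_tokens, vocab_size] buffer is allocated, and the chosen tokens' logprobs are gathered from it afterwards.

    import torch

    out = torch.randn(8, 32000)                 # stand-in for the prefill logits
    token_ids = torch.randint(0, 32000, (8,))   # stand-in for prefill_tokens_indices

    torch.log_softmax(out, -1, out=out)         # in place: `out` now holds logprobs
    logprobs = torch.gather(out, 1, token_ids.view(-1, 1)).view(-1).tolist()
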
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 + ) max_input_length = max(max_input_length, b.max_input_length) max_current_length = max(max_current_length, b.max_current_length) - - speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 max_length = max( max_length, max( prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens for prompt_length, stopping_criteria in zip( b.prompt_lengths, b.stopping_criterias ) From 46b06b7797c6312553ffcbb37bc71e552595ce68 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 10:43:45 -0700 Subject: [PATCH 25/34] Rename prefill chunking -> chunked prefill --- proto/generate.proto | 2 +- router/src/infer.rs | 14 +++++++------- router/src/server.rs | 2 +- server/lorax_server/models/model.py | 14 +++++++------- server/lorax_server/utils/state.py | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index f03fb7d2b..a44bcdb54 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -57,7 +57,7 @@ message InfoResponse { bool supports_generation = 8; bool supports_embeddings = 9; bool supports_classification = 10; - bool prefill_chunking = 11; + bool chunked_prefill = 11; } /// Empty request diff --git a/router/src/infer.rs b/router/src/infer.rs index bef99b000..ba57d206d 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -181,7 +181,7 @@ impl Infer { speculate: u32, preloaded_adapters: Vec, prefix_caching: bool, - prefill_chunking: bool, + chunked_prefill: bool, is_causal_lm: bool, ) -> Self { let adapter_event = Arc::new(AdapterEvent { @@ -251,7 +251,7 @@ impl Infer { generation_health, adapter_scheduler.clone(), eager_prefill, - prefill_chunking, + chunked_prefill, )); // Inference limit with a semaphore @@ -900,7 +900,7 @@ async fn batching_task( generation_health: Arc, adapter_scheduler: AdapterScheduler, eager_prefill: bool, - prefill_chunking: bool, + chunked_prefill: bool, ) { // Infinite loop loop { @@ -940,7 +940,7 @@ async fn batching_task( adapter_scheduler.remove_errored_adapters().await; let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let (min_size, _max_size, prefill_token_budget) = if prefill_chunking { + let (min_size, _max_size, prefill_token_budget) = if chunked_prefill { // Since the next batch will be concatenated with the current batch, // the current batch tokens must be subtracted to the prefill budget let prefill_token_budget = @@ -1000,14 +1000,14 @@ async fn batching_task( if min_size.is_some() { metrics::increment_counter!("lorax_batch_concat", "reason" => "backpressure"); } else { - if prefill_chunking { + if chunked_prefill { metrics::increment_counter!("lorax_batch_concat", "reason" => "chunking") } else { metrics::increment_counter!("lorax_batch_concat", "reason" => "wait_exceeded") }; } - let cached_batch = if prefill_chunking { + let cached_batch = if chunked_prefill { // Concat current batch to the new one batches.pop() } else { @@ -1048,7 +1048,7 @@ async fn batching_task( // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { batches.push(new_cached_batch); - } else if prefill_chunking { + } else if chunked_prefill { // New cached batch is empty, no work left break; } diff --git a/router/src/server.rs b/router/src/server.rs index fb07839ca..e2806dd87 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1401,7 +1401,7 @@ pub async fn run( 
shard_info.speculate, shard_info.preloaded_adapters, prefix_caching, - shard_info.prefill_chunking, + shard_info.chunked_prefill, is_causal_lm, ); diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index babee5d80..ec2da1a2e 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -17,7 +17,7 @@ load_and_merge_adapters, ) from lorax_server.utils.sources import HUB -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFILL_CHUNKING, get_speculative_tokens, set_supports_chunking +from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, CHUNKED_PREFILL, get_speculative_tokens, set_supports_chunking from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import shard_on_dim @@ -76,20 +76,20 @@ def __init__( speculation_tokens = get_speculative_tokens() - supports_chunking = supports_chunking and PREFILL_CHUNKING + supports_chunking = supports_chunking and CHUNKED_PREFILL if supports_chunking: if speculation_tokens != 0: logger.warning( - "Prefill chunking does not support speculation yet. " - "Prefill chunking will be turned off", + "Chunked prefill does not support speculation yet. " + "Chunked prefill will be disabled", ) supports_chunking = False if not FLASH_INFER: logger.warning( - "Prefill chunking is only supported with `flashinfer` backend.", + "Chunked prefill is only supported with `flashinfer` backend.", ) supports_chunking = False - logger.info(f"Using experimental prefill chunking = {supports_chunking}") + logger.info(f"Using experimental chunked prefill = {supports_chunking}") self.supports_chunking = supports_chunking set_supports_chunking(supports_chunking) @@ -124,7 +124,7 @@ def info(self) -> InfoResponse: supports_generation=self.supports_text_generation, supports_embeddings=self.supports_embeddings, supports_classification=self.supports_classification, - prefill_chunking=self.supports_chunking, + chunked_prefill=self.supports_chunking, ) @property diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py index f60f96b79..130cdc6e5 100644 --- a/server/lorax_server/utils/state.py +++ b/server/lorax_server/utils/state.py @@ -9,7 +9,7 @@ PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", "")) -PREFILL_CHUNKING = bool(os.environ.get("PREFILL_CHUNKING", "")) +CHUNKED_PREFILL = bool(os.environ.get("CHUNKED_PREFILL", "")) # Always use flashinfer when prefix caching is enabled FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING @@ -19,7 +19,7 @@ logger.info("Backend = fa2") logger.info(f"Prefix caching = {PREFIX_CACHING}") -logger.info(f"Prefill chunking = {PREFILL_CHUNKING}") +logger.info(f"Chunked prefill = {CHUNKED_PREFILL}") SUPPORTS_CHUNKING: Optional[bool] = None MAX_PREFILL_TOKENS: Optional[int] = None From 0fc2641bc0b06a4446e86880feb21a83251ea3a6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 10:49:24 -0700 Subject: [PATCH 26/34] Added launcher args --- launcher/src/main.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e6219500d..27ac09800 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -329,6 +329,11 @@ struct Args { #[clap(long, env)] eager_prefill: Option, + /// Split prefill requests into multiple chunks and batch them with decode requests. For high QPS scenarios, this + /// can greatly improve throughput by overlapping request types. See: https://arxiv.org/pdf/2308.16369. 
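
Tying the pieces above together (a sketch with illustrative names, not code from the patch): chunking is opted into through the CHUNKED_PREFILL environment variable that the launcher flag below forwards to the shards, is silently disabled when speculative decoding is active or the flashinfer backend is not in use, and on the router side the prefill budget for the next chunk must leave room for the tokens already held by the batch it will be concatenated with.

    import os

    def chunking_enabled(model_supports_chunking: bool, speculative_tokens: int, flash_infer: bool) -> bool:
        # mirrors the checks in model.py and state.py above
        enabled = model_supports_chunking and bool(os.environ.get("CHUNKED_PREFILL", ""))
        if speculative_tokens != 0 or not flash_infer:
            enabled = False
        return enabled

    def prefill_token_budget(max_batch_prefill_tokens: int, current_batch_tokens: int) -> int:
        # the router computes the equivalent in Rust; exact saturation behavior assumed here
        return max(max_batch_prefill_tokens - current_batch_tokens, 0)
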
+ #[clap(long, env)] + chunked_prefill: Option, + /// Whether to use the prefix caching mechanism. This will skip computing attention on previously cached prefixes /// in the prompt. Useful in cases where many queries need to be run over a shared context, or for long multi-turn /// chats conversations. @@ -496,6 +501,7 @@ fn shard_manager( cuda_memory_fraction: f32, adapter_memory_fraction: f32, prefix_caching: Option, + chunked_prefill: Option, merge_adapter_weights: bool, backend: Backend, otlp_endpoint: Option, @@ -639,6 +645,11 @@ fn shard_manager( envs.push(("PREFIX_CACHING".into(), prefix_caching.to_string().into())); } + // Chunked prefill + if let Some(chunked_prefill) = chunked_prefill { + envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into())); + } + // Backend if backend == Backend::FlashInfer { envs.push(("FLASH_INFER".into(), "1".into())); @@ -1093,6 +1104,7 @@ fn spawn_shards( let cuda_memory_fraction = args.cuda_memory_fraction; let adapter_memory_fraction = args.adapter_memory_fraction; let prefix_caching = args.prefix_caching; + let chunked_prefill = args.chunked_prefill; let merge_adapter_weights = args.merge_adapter_weights; let backend = args.backend; let embedding_dim = args.embedding_dim; @@ -1125,6 +1137,7 @@ fn spawn_shards( cuda_memory_fraction, adapter_memory_fraction, prefix_caching, + chunked_prefill, merge_adapter_weights, backend, otlp_endpoint, From 536ab5fee6f554c1b0e8952375d90be733d748e7 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 13:48:40 -0700 Subject: [PATCH 27/34] Revert docker --- .github/workflows/build.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c1db5f731..13b9e96ca 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,7 +5,6 @@ on: push: branches: - 'main' - - 'chunked-prefill' tags: - 'v*' @@ -70,7 +69,10 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=chunked-prefill-2,enable=${{ github.ref == 'refs/heads/chunked-prefill' }} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix=,suffix=,format=short + type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} - name: Create a hash from tags env: From 646463ebc721de2cd270c893e783905105ff90cd Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 15:23:24 -0700 Subject: [PATCH 28/34] Fix vlms --- server/lorax_server/models/flash_causal_lm.py | 17 ++++---- server/lorax_server/models/mllama.py | 14 +++---- server/lorax_server/models/vlm_causal_lm.py | 18 ++++----- server/lorax_server/utils/graph.py | 40 +++++++++---------- 4 files changed, 44 insertions(+), 45 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index dafda244a..0e6a48d9c 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1335,8 +1335,8 @@ def _forward_context( cu_seqlen_prefill: Optional[torch.Tensor], input_lengths: List[int], input_lengths_tensor: torch.Tensor, - cache_lens: List[int], - cache_lens_tensor: torch.Tensor, + cache_lengths: List[int], + cache_lengths_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if not FLASH_INFER: @@ -1347,13 +1347,12 @@ def _forward_context( use_prefill_with_paged_kv_state, ) - # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) if cu_seqlen_prefill is not None: return 
use_prefill_with_paged_kv_state( state=(state if state is not None else self.prefill_with_paged_kv_state), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + cache_lens_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -1365,7 +1364,7 @@ def _forward_context( assert input_lengths_tensor is not None return use_decode_state( state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + cache_lens_tensor, + input_lengths=input_lengths_tensor + cache_lengths_tensor, block_tables=block_tables, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, @@ -1446,8 +1445,8 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> cu_seqlen_prefill=batch.cu_seqlen_prefill, input_lengths=batch.input_lengths, input_lengths_tensor=input_lengths, - cache_lens=batch.cache_lengths, - cache_lens_tensor=cache_lengths_tensor, + cache_lengths=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): out = model.forward( input_ids=input_ids, @@ -1472,8 +1471,8 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> block_tables=block_tables, slots=slots, seqlen=seqlen, - cache_lens=batch.cache_lengths, - cache_lens_tensor=cache_lengths_tensor, + cache_lengths=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py index 342570094..aef0f0e43 100644 --- a/server/lorax_server/models/mllama.py +++ b/server/lorax_server/models/mllama.py @@ -194,7 +194,7 @@ def forward( new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - prefix_lens_tensor = (batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) + cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) # Add Copy the block tables for all members block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() @@ -210,7 +210,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor + cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -221,23 +221,23 @@ def forward( max_s = min(self.max_past(), max_s) # TODO: cuda graph - input_lengths = input_lengths + prefix_lens_tensor + input_lengths = input_lengths + cache_lengths_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - cache_lengths=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=batch.input_lengths, input_lengths_tensor=input_lengths, - prefix_lens=batch.prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): # TODO(travis): is this needed? 
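
Why the flashinfer states above now receive input_lengths + cache_lengths (illustrative numbers): under chunked prefill the query only covers the tokens of the current chunk, but attention still has to span everything already sitting in the paged KV cache.

    import torch

    input_lengths = torch.tensor([256, 64], dtype=torch.int32)  # tokens computed in this forward
    cache_lengths = torch.tensor([512, 0], dtype=torch.int32)   # tokens already in the KV cache
    kv_lengths = input_lengths + cache_lengths                   # what the planner sees: [768, 64]
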
- # max_k = (input_lengths + prefix_lens_tensor).max().item() + # max_k = (input_lengths + cache_lengths_tensor).max().item() if batch.pixel_values is not None: cross_attention_states = self.model.vision_forward( pixel_values=batch.pixel_values, diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index ce2d9b457..734454e82 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -278,7 +278,7 @@ def __init__( processor=processor, trust_remote_code=trust_remote_code, # FIXME: VLM do not work with context chunking yet - support_chunking=False, + supports_chunking=False, **kwargs, ) @@ -314,7 +314,7 @@ def forward( new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - prefix_lens_tensor = (batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) + cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1) # Add Copy the block tables for all members block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() @@ -329,7 +329,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor + cache_lengths_tensor = batch.cache_lengths_tensor max_s = batch.max_seqlen if cu_seqlen_prefill is None and self.max_past() is not None: @@ -350,20 +350,20 @@ def forward( model = self.model_graph_wrapper if not use_graph: - input_lengths = input_lengths + prefix_lens_tensor + input_lengths = input_lengths + cache_lengths_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, - cache_lengths=batch.prefix_lens, + cache_lengths=batch.cache_lengths, ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=batch.input_lengths, input_lengths_tensor=input_lengths, - prefix_lens=batch.prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, ): # input_lengths = Seqlen(input_lengths=input_lengths) out = model.forward( @@ -392,8 +392,8 @@ def forward( block_tables=block_tables, slots=slots, input_lengths=input_lengths, - prefix_lens=batch.prefix_lens, - prefix_lens_tensor=prefix_lens_tensor, + cache_lengths=batch.cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 8c1fb2170..dd74f6b7f 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -85,8 +85,8 @@ class GraphState: slots: torch.Tensor seqlen: Seqlen input_lengths: List[int] - cache_lens: List[int] - cache_lens_tensor: torch.Tensor + cache_lengths: List[int] + cache_lengths_tensor: torch.Tensor adapter_data: AdapterBatchData traced_adapter_layer_names: Set[str] state: Any = None @@ -150,8 +150,8 @@ def get_max_graph_state( max_k=max_total_tokens, ), input_lengths=input_lengths.tolist(), - cache_lens=cache_lengths, - cache_lens_tensor=cache_lengths_tensor, + cache_lengths=cache_lengths, + 
cache_lengths_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), @@ -245,8 +245,8 @@ def trace( block_tables = max_input_state.block_tables[:batch_size] input_lengths = max_input_state.input_lengths[:batch_size] input_lengths_tensor = max_input_state.seqlen.input_lengths[:batch_size] - cache_lengths = max_input_state.cache_lens[:batch_size] - cache_lengths_tensor = max_input_state.cache_lens_tensor[:batch_size] + cache_lengths = max_input_state.cache_lengths[:batch_size] + cache_lengths_tensor = max_input_state.cache_lengths_tensor[:batch_size] state = None if FLASH_INFER: @@ -284,8 +284,8 @@ def trace( max_k=max_total_tokens, ), input_lengths=input_lengths, - cache_lens=cache_lengths, - cache_lens_tensor=cache_lengths_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, adapter_data=AdapterBatchData( meta=AdapterBatchMetadata( adapter_indices=max_input_state.adapter_data.meta.adapter_indices[:batch_size], @@ -307,8 +307,8 @@ def trace( cu_seqlen_prefill=None, input_lengths=input_lengths, input_lengths_tensor=input_state.seqlen.input_lengths, - cache_lens=cache_lengths, - cache_lens_tensor=cache_lengths_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, state=input_state.state, ): # warmup @@ -356,8 +356,8 @@ def forward( block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - cache_lens: List[int], - cache_lens_tensor: torch.Tensor, + cache_lengths: List[int], + cache_lengths_tensor: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, lm_head_indices: Optional[torch.Tensor] = None, @@ -367,8 +367,8 @@ def forward( pad_and_fill(self.input_state.slots, slots, SLOT_PAD_VALUE) pad_and_fill(self.input_state.seqlen.input_lengths, seqlen.input_lengths, 0) pad_and_fill(self.input_state.seqlen.cache_lengths, seqlen.cache_lengths, 0) - self.input_state.cache_lens[: len(cache_lens)] = cache_lens - pad_and_fill(self.input_state.cache_lens_tensor, cache_lens_tensor, 0) + self.input_state.cache_lengths[: len(cache_lengths)] = cache_lengths + pad_and_fill(self.input_state.cache_lengths_tensor, cache_lengths_tensor, 0) if FLASH_INFER: block_tables = block_tables_to_ragged( @@ -408,8 +408,8 @@ def forward( cu_seqlen_prefill=None, input_lengths=seqlen.input_lengths, input_lengths_tensor=self.input_state.seqlen.input_lengths, - cache_lens=self.input_state.cache_lens, - cache_lens_tensor=self.input_state.cache_lens_tensor, + cache_lengths=self.input_state.cache_lengths, + cache_lengths_tensor=self.input_state.cache_lengths_tensor, state=self.input_state.state, ): self.graph.replay() @@ -560,8 +560,8 @@ def forward( block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - cache_lens: List[int], - cache_lens_tensor: torch.Tensor, + cache_lengths: List[int], + cache_lengths_tensor: torch.Tensor, max_s: int, adapter_data: AdapterBatchData, lm_head_indices: Optional[torch.Tensor] = None, @@ -606,8 +606,8 @@ def forward( block_tables=block_tables, slots=slots, seqlen=seqlen, - cache_lens=cache_lens, - cache_lens_tensor=cache_lens_tensor, + cache_lengths=cache_lengths, + cache_lengths_tensor=cache_lengths_tensor, max_s=max_s, adapter_data=adapter_data, lm_head_indices=lm_head_indices, From 65076e2a13e12ee48db6977b8003776c7454d1e0 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 15:33:11 -0700 Subject: [PATCH 29/34] Seqlen --- server/lorax_server/models/mllama.py | 15 ++++++++++++--- 
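
The pad_and_fill calls in GraphWrapper.forward above rely on a helper whose body is not shown in this diff; from the call sites its contract appears to be: copy the live values into the head of the statically sized graph input and pad the tail. A guess at a minimal version (an assumption, not the actual lorax helper):

    import torch

    def pad_and_fill(dest: torch.Tensor, values: torch.Tensor, pad_value: int) -> None:
        # hypothetical reconstruction: the real implementation may differ
        dest[: values.shape[0]] = values
        dest[values.shape[0]:].fill_(pad_value)
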
server/lorax_server/models/vlm_causal_lm.py | 17 +++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py index aef0f0e43..d79916aad 100644 --- a/server/lorax_server/models/mllama.py +++ b/server/lorax_server/models/mllama.py @@ -2,6 +2,7 @@ from io import BytesIO from typing import Dict, Iterable, List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch from opentelemetry import trace from PIL import Image @@ -181,7 +182,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids @@ -211,7 +212,7 @@ def forward( slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor cache_lengths_tensor = batch.cache_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices if cu_seqlen_prefill is None and self.max_past() is not None: @@ -219,6 +220,14 @@ def forward( # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=batch.max_input_length, + max_k=batch.max_current_length, + ) # TODO: cuda graph input_lengths = input_lengths + cache_lengths_tensor @@ -255,7 +264,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 734454e82..8c3755c3b 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -1,6 +1,7 @@ from io import BytesIO from typing import Iterable, List, Optional, Tuple, Type +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from loguru import logger @@ -302,7 +303,7 @@ def forward( block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length speculative_ids = batch.speculative_ids @@ -330,7 +331,7 @@ def forward( slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor cache_lengths_tensor = batch.cache_lengths_tensor - max_s = batch.max_seqlen + max_s = batch.max_current_length if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache @@ -348,6 +349,14 @@ def forward( ): use_graph = True model = self.model_graph_wrapper + + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=batch.max_input_length, + max_k=batch.max_current_length, + ) if not use_graph: input_lengths = input_lengths + cache_lengths_tensor @@ -373,7 +382,7 @@ def forward( kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=batch.prefill_cache_indices, @@ -391,7 +400,7 @@ def forward( kv_cache=self.kv_cache, 
block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, cache_lengths=batch.cache_lengths, cache_lengths_tensor=cache_lengths_tensor, max_s=max_s, From 787be034fc3602fa83cbced64d82c61c3e42d441 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 15:41:16 -0700 Subject: [PATCH 30/34] Fixed llava next --- server/lorax_server/models/custom_modeling/llava_next.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/llava_next.py b/server/lorax_server/models/custom_modeling/llava_next.py index 2a32ca4c8..2ece2e821 100644 --- a/server/lorax_server/models/custom_modeling/llava_next.py +++ b/server/lorax_server/models/custom_modeling/llava_next.py @@ -16,6 +16,7 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.utils.checkpoint from torch import nn @@ -168,7 +169,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -255,7 +256,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, cross_attention_states=None, From fcfa679ad6db0e4fccc68b55cac88ce64336b10d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 15:53:36 -0700 Subject: [PATCH 31/34] Mllama --- server/lorax_server/models/custom_modeling/mllama.py | 8 +++++--- server/lorax_server/models/mllama.py | 7 +++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py index dc6477c9e..989c74c1f 100644 --- a/server/lorax_server/models/custom_modeling/mllama.py +++ b/server/lorax_server/models/custom_modeling/mllama.py @@ -16,6 +16,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + import flash_attn_2_cuda import torch import torch.nn.functional as F @@ -875,7 +877,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor], @@ -895,7 +897,7 @@ def forward( indices = [] for index in image_indices: cu_q.append(offset) - length = input_lengths[index].item() + length = seqlen.input_lengths[index].item() assert index < cu_seqlen_prefill.shape[0] input_ids_offset = cu_seqlen_prefill[index] indices.extend(range(input_ids_offset, input_ids_offset + length)) @@ -947,7 +949,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, adapter_data=adapter_data, prefill_cache_indices=prefill_cache_indices, diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py index d79916aad..78431f17e 100644 --- a/server/lorax_server/models/mllama.py +++ b/server/lorax_server/models/mllama.py @@ -3,6 +3,7 @@ from typing import Dict, Iterable, List, Optional, Tuple from lorax_server.utils.attention.common import Seqlen +import numpy as np import torch from opentelemetry import trace from PIL import Image @@ -151,6 +152,12 @@ def from_pb( # XXX: <|image|> token is actually out of 
bounds and bugs out the logit processors. batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(max=config.text_config.vocab_size - 1) + if isinstance(batch.input_ids, list): + if len(batch) > 1: + input_ids = np.concatenate(batch.input_ids, dtype=np.int64) + else: + input_ids = batch.input_ids[0] + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) if image_inputs is not None: From df5ae306be8bdb9a727ae4aaaaa3f42070a8cfae Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 16:08:37 -0700 Subject: [PATCH 32/34] Fixed embeddings --- server/lorax_server/models/flash_causal_lm.py | 13 ------------- server/lorax_server/server.py | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 0e6a48d9c..ac137a859 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -353,19 +353,6 @@ def from_pb( adapter_meta=None, ) - @classmethod - def from_pb_embed( - self, - pb: generate_pb2.EmbedRequest, - tokenizer: PreTrainedTokenizerBase, - tokenizers: TokenizerManager, - processor, - config, - dtype, - device, - ) -> "FlashCausalLMBatch": - return self.from_pb(pb, tokenizer, tokenizers, None, None, dtype, device) - @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": if len(request_ids) == 0: diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index e1d2139eb..90cecac71 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -138,7 +138,7 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context): if not self.model.supports_embeddings: raise ValueError("Model does not support embeddings") - batch = self.model.batch_type.from_pb_embed( + batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.tokenizers, From 401f1ae8e8ee855462924a48bc2adbc4c01bd731 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 16:20:35 -0700 Subject: [PATCH 33/34] input_lengths -> seqlen --- .../custom_modeling/flash_cohere_modeling.py | 18 +++++++------ .../custom_modeling/flash_dbrx_modeling.py | 21 ++++++++------- .../custom_modeling/flash_gemma2_modeling.py | 17 ++++++------ .../custom_modeling/flash_gemma_modeling.py | 17 ++++++------ .../custom_modeling/flash_gpt2_modeling.py | 17 ++++++------ .../custom_modeling/flash_mistral_modeling.py | 20 +++++++------- .../custom_modeling/flash_mixtral_modeling.py | 22 ++++++++------- .../custom_modeling/flash_neox_modeling.py | 19 ++++++------- .../custom_modeling/flash_phi3_modeling.py | 18 +++++++------ .../custom_modeling/flash_phi_modeling.py | 17 ++++++------ .../custom_modeling/flash_qwen_modeling.py | 18 +++++++------ .../custom_modeling/flash_rw_modeling.py | 27 ++++++++++--------- .../flash_santacoder_modeling.py | 17 ++++++------ 13 files changed, 133 insertions(+), 115 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index 8f3d98f7c..1c21fd135 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -20,6 +20,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + 
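
The list handling added to mllama.from_pb above, in isolation (toy values): when the batch still carries per-request Python lists of input ids, they are flattened into the single ragged int64 tensor that the flash models expect before the vocab clamp is applied.

    import numpy as np
    import torch

    batch_input_ids = [[1, 2, 3], [4, 5]]          # one list per request
    if len(batch_input_ids) > 1:
        flat = np.concatenate(batch_input_ids, dtype=np.int64)
    else:
        flat = batch_input_ids[0]
    input_ids = torch.tensor(flat, dtype=torch.int64)   # tensor([1, 2, 3, 4, 5])
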
import dropout_layer_norm import rotary_emb import torch @@ -252,7 +254,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -303,7 +305,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -387,7 +389,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -402,7 +404,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -453,7 +455,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -474,7 +476,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -518,7 +520,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -531,7 +533,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index db79bcdf4..84a7accd4 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -15,6 +15,7 @@ from typing import Any, List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import numpy as np import torch import torch.distributed @@ -394,7 +395,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -441,7 +442,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -475,7 +476,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -490,7 +491,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -892,7 +893,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -906,7 +907,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -947,7 +948,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -968,7 +969,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -1003,7 +1004,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -1016,7 +1017,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index 1ad9e532c..5d2538fff 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ 
b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -20,6 +20,7 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -239,7 +240,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -283,7 +284,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -379,7 +380,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -394,7 +395,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -445,7 +446,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -466,7 +467,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -510,7 +511,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -524,7 +525,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index ef3328459..e07b6a97a 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -15,6 +15,7 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -245,7 +246,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -288,7 +289,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -393,7 +394,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -409,7 +410,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -458,7 +459,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -482,7 +483,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -509,7 +510,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -522,7 +523,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index 1845b3458..ad2c663c7 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -20,6 +20,7 @@ from typing 
import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -149,7 +150,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -182,7 +183,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -255,7 +256,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -267,7 +268,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -322,7 +323,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -337,7 +338,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -361,7 +362,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -374,7 +375,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index cc6f93dc6..c4124b15b 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -20,6 +20,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + # Flash attention imports import dropout_layer_norm import torch @@ -308,7 +310,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -358,7 +360,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -445,7 +447,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -461,7 +463,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -509,7 +511,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor], @@ -531,7 +533,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -581,7 +583,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -594,7 +596,7 @@ def forward( # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + seqlen = torch.clamp(seqlen, max=self.max_past) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -604,7 +606,7 @@ def forward( kv_cache, block_tables, slots, 
- input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 2a7622c99..d189b33b9 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -21,6 +21,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + # Flash attention imports import dropout_layer_norm import numpy as np @@ -366,7 +368,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -382,7 +384,7 @@ def forward( kv_cache: The key-value cache. block_tables: The block tables for attention computation. slots: The number of slots. - input_lengths: The lengths of the input sequences. + seqlen: The lengths of the input sequences. max_s: The maximum sequence length. adapter_data: The adapter data. prefill_cache_indices: The indices for prefilling the cache. @@ -436,7 +438,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -835,7 +837,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -851,7 +853,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -896,7 +898,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor], @@ -918,7 +920,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, @@ -955,7 +957,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -968,7 +970,7 @@ def forward( # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + seqlen = torch.clamp(seqlen, max=self.max_past) hidden_states = self.model( input_ids, @@ -977,7 +979,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, prefill_cache_indices, diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index 0b0ff025f..2e1d00829 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -20,6 +20,7 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -122,7 +123,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -158,7 +159,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -222,7 +223,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): if 
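In the Mistral and Mixtral decode paths above, the sliding-window clamp now receives a Seqlen object rather than a plain tensor, so torch.clamp(seqlen, max=self.max_past) would fail at runtime. The Seqlen.clamp helper that appears later in this series (server/lorax_server/utils/attention/common.py) looks like the intended spelling; a minimal sketch, assuming that helper clamps input_lengths in place and returns self:

    # Clamp in decode mode as paged attention requires clamped values whereas the
    # flash attention kernel requires the true values (sketch of the clamp call only).
    max_s = min(self.max_past, max_s)
    seqlen = seqlen.clamp(max=self.max_past)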
self.use_parallel_residual: @@ -236,7 +237,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -260,7 +261,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -307,7 +308,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) @@ -327,7 +328,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -352,7 +353,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -364,7 +365,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index b0c10b9bb..78a721b15 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -20,6 +20,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + # Flash attention imports import dropout_layer_norm import torch @@ -254,7 +256,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -298,7 +300,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -381,7 +383,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -396,7 +398,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -444,7 +446,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -465,7 +467,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -500,7 +502,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -513,7 +515,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 16610de75..82f731ceb 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -11,6 +11,7 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -140,7 +141,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -184,7 +185,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -262,7 +263,7 @@ def forward( kv_cache, block_tables, slots, - 
input_lengths, + seqlen, max_s, adapter_data, ): @@ -276,7 +277,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -327,7 +328,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -348,7 +349,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -382,7 +383,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -395,7 +396,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 57f3cce21..525bd0ba3 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -6,6 +6,8 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen + # Flash attention imports import dropout_layer_norm import torch @@ -227,7 +229,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -271,7 +273,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -357,7 +359,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -372,7 +374,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -422,7 +424,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, ) -> torch.Tensor: @@ -443,7 +445,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -478,7 +480,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor] = None, @@ -491,7 +493,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index f1fff63ea..861bdedfa 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -1,5 +1,6 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -153,7 +154,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -198,7 +199,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -260,7 +261,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ 
-310,7 +311,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -387,7 +388,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): if self.parallel_attn: @@ -401,7 +402,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -423,7 +424,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -470,7 +471,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): ln_attn, residual = self.ln_attn(hidden_states, residual) @@ -485,7 +486,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -541,7 +542,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) @@ -561,7 +562,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -587,7 +588,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -599,7 +600,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 7efa57536..690d8e589 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,5 +1,6 @@ from typing import List, Optional, Tuple +from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -220,7 +221,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.c_attn(hidden_states) @@ -258,7 +259,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -313,7 +314,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) @@ -323,7 +324,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -377,7 +378,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -394,7 +395,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -418,7 +419,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -430,7 +431,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: From 1d9f5d46ff37bba0d968ca2fb606280f82c4a56a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 18 Oct 2024 16:21:11 -0700 Subject: [PATCH 34/34] ruff --- .../custom_modeling/flash_cohere_modeling.py | 3 
+- .../custom_modeling/flash_dbrx_modeling.py | 2 +- .../custom_modeling/flash_gemma2_modeling.py | 2 +- .../custom_modeling/flash_gemma_modeling.py | 2 +- .../custom_modeling/flash_gpt2_modeling.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 3 +- .../custom_modeling/flash_mistral_modeling.py | 3 +- .../custom_modeling/flash_mixtral_modeling.py | 19 +- .../custom_modeling/flash_neox_modeling.py | 2 +- .../custom_modeling/flash_phi3_modeling.py | 3 +- .../custom_modeling/flash_phi_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 3 +- .../custom_modeling/flash_qwen_modeling.py | 3 +- .../custom_modeling/flash_rw_modeling.py | 2 +- .../flash_santacoder_modeling.py | 2 +- .../models/custom_modeling/llava_next.py | 2 +- .../models/custom_modeling/mllama.py | 3 +- server/lorax_server/models/flash_causal_lm.py | 221 ++++++------------ server/lorax_server/models/mllama.py | 4 +- server/lorax_server/models/model.py | 11 +- server/lorax_server/models/vlm_causal_lm.py | 4 +- server/lorax_server/server.py | 6 +- server/lorax_server/utils/attention/common.py | 4 +- server/lorax_server/utils/attention/utils.py | 4 +- .../utils/flashinfer_attention.py | 20 +- server/lorax_server/utils/graph.py | 2 +- server/lorax_server/utils/paged_attention.py | 2 +- 27 files changed, 119 insertions(+), 217 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index 1c21fd135..64dddfc36 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -20,8 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - import dropout_layer_norm import rotary_emb import torch @@ -32,6 +30,7 @@ from lorax_server.adapters.weights import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, MultiAdapterHead, diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index 84a7accd4..5c16c81df 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -15,7 +15,6 @@ from typing import Any, List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import numpy as np import torch import torch.distributed @@ -27,6 +26,7 @@ from lorax_server.adapters.weights import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, FastLinear, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index 5d2538fff..e51e26ed1 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -20,7 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -40,6 +39,7 @@ from lorax_server.layers.rotary 
import PositionRotaryEmbedding from lorax_server.layers.tensor_parallel import TensorParallelHead from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import MultiAdapterHead, TensorParallelAdapterRowLinear, TensorParallelMultiAdapterLinear from lorax_server.utils.lora import ( DOWN_PROJ, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index e07b6a97a..139f7e822 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -15,7 +15,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -26,6 +25,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( PositionRotaryEmbedding, TensorParallelAdapterRowLinear, diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index ad2c663c7..eeb6e8d38 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -20,7 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -31,6 +30,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, TensorParallelAdapterRowLinear, diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index e5ccd2617..524fa792d 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -20,8 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - # Flash attention imports import dropout_layer_norm import torch @@ -32,6 +30,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index c4124b15b..88b69dbb5 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -20,8 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - # Flash attention imports import dropout_layer_norm import torch @@ -32,6 +30,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from 
lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA from lorax_server.utils.layers import ( MultiAdapterHead, diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index d189b33b9..08acee663 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -21,8 +21,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - # Flash attention imports import dropout_layer_norm import numpy as np @@ -35,6 +33,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA from lorax_server.utils.layers import ( FastLinear, @@ -189,13 +188,15 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear(get_linear( - weight, - bias=None, - quantize=config.quantize, - weight_scale=weight_scale, - input_scale=input_scale, - )) + return TensorParallelColumnLinear( + get_linear( + weight, + bias=None, + quantize=config.quantize, + weight_scale=weight_scale, + input_scale=input_scale, + ) + ) def _load_experts(config, prefix, mat, weights): diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index 2e1d00829..cc6df3382 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -20,7 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -30,6 +29,7 @@ from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, PositionRotaryEmbedding, diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index 78a721b15..b0b48688d 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -20,8 +20,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - # Flash attention imports import dropout_layer_norm import torch @@ -33,6 +31,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 82f731ceb..d600928b9 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -11,7 +11,6 @@ from typing import List, 
Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -20,6 +19,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, MultiAdapterHead, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 805c4576d..261c6cff5 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -6,8 +6,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - # Flash attention imports import dropout_layer_norm import torch @@ -18,6 +16,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 525bd0ba3..e55c10544 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -6,8 +6,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - # Flash attention imports import dropout_layer_norm import torch @@ -19,6 +17,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 861bdedfa..4f3b36765 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -1,6 +1,5 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -9,6 +8,7 @@ from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, PositionRotaryEmbedding, diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 690d8e589..4e98b97a2 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,6 +1,5 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from torch import nn @@ -8,6 +7,7 @@ from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention +from 
lorax_server.utils.attention.common import Seqlen from lorax_server.utils.layers import ( FastLayerNorm, TensorParallelColumnLinear, diff --git a/server/lorax_server/models/custom_modeling/llava_next.py b/server/lorax_server/models/custom_modeling/llava_next.py index 2ece2e821..bede2691e 100644 --- a/server/lorax_server/models/custom_modeling/llava_next.py +++ b/server/lorax_server/models/custom_modeling/llava_next.py @@ -16,7 +16,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import torch import torch.utils.checkpoint from torch import nn @@ -32,6 +31,7 @@ load_text_model, load_vision_model, ) +from lorax_server.utils.attention.common import Seqlen def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py index 989c74c1f..c48448b36 100644 --- a/server/lorax_server/models/custom_modeling/mllama.py +++ b/server/lorax_server/models/custom_modeling/mllama.py @@ -16,8 +16,6 @@ from typing import List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen - import flash_attn_2_cuda import torch import torch.nn.functional as F @@ -36,6 +34,7 @@ FlashLlamaForCausalLM, FlashLlamaLayer, ) +from lorax_server.utils.attention.common import Seqlen # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index ac137a859..3952f01d1 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union -from lorax_server.utils.attention.common import Seqlen import numpy as np import torch import torch.distributed @@ -16,16 +15,15 @@ from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.models.model import Model from lorax_server.models.types import ( - AlternativeTokens, Batch, GeneratedText, Generation, NextTokens, - PrefillTokens, ) from lorax_server.pb import generate_pb2 from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed from lorax_server.utils.graph import GraphCache @@ -33,7 +31,14 @@ from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.sources import HUB from lorax_server.utils.sources.hub import weight_files -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, PREFIX_CACHING, get_max_prefill_tokens, get_speculative_tokens, get_supports_chunking, warmup_mode +from lorax_server.utils.state import ( + BLOCK_SIZE, + FLASH_INFER, + get_max_prefill_tokens, + get_speculative_tokens, + get_supports_chunking, + warmup_mode, +) from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import Weights @@ -103,7 +108,7 @@ class FlashCausalLMBatch(Batch): all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor - # Lengths of all generations present in the batch + # Lengths of all generations present in the batch input_lengths: List[int] # size 
[b], containing the number of blocks that can be retrieved from the cache cache_lengths: List[int] @@ -136,12 +141,9 @@ def to_pb(self) -> generate_pb2.CachedBatch: size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, current_tokens=( - sum([len(i) for i in self.input_ids]) - if isinstance(self.input_ids, list) - else len(self.input_ids) + sum([len(i) for i in self.input_ids]) if isinstance(self.input_ids, list) else len(self.input_ids) ), ) - @classmethod def to_pb_embed(self, batch, embeddings) -> generate_pb2.EmbedResponse: @@ -179,7 +181,7 @@ def from_pb( batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)[ "input_ids" ] - + speculative_tokens = get_speculative_tokens() cache_lengths = [] @@ -213,12 +215,10 @@ def from_pb( prompt_lengths.append(prompt_length) cache_length = r.cache_len - assert ( - cache_length <= prompt_length - ), f"Prefix {cache_length} vs input {prompt_length}" + assert cache_length <= prompt_length, f"Prefix {cache_length} vs input {prompt_length}" if cache_length == prompt_length: assert False, "unreachable" - + # TODO(travis): double-check prefix caching # if PREFIX_CACHING: # prefix_len = r.prefix_len @@ -239,12 +239,8 @@ def from_pb( assert get_supports_chunking() assert input_length > 0 - postfix_ids = tokenized_input[ - cache_length : cache_length + input_length - ] - assert ( - len(postfix_ids) == input_length - ), "Rust and Python tokenizers are not aligned" + postfix_ids = tokenized_input[cache_length : cache_length + input_length] + assert len(postfix_ids) == input_length, "Rust and Python tokenizers are not aligned" else: # Use all the remaining ids postfix_ids = tokenized_input[cache_length:] @@ -310,9 +306,7 @@ def from_pb( for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) - prompt_lengths_tensor = torch.tensor( - prompt_lengths, dtype=torch.int32, device=device - ) + prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -414,9 +408,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": request_input_length = self.input_lengths[idx] request_cache_length = self.cache_lengths[idx] max_input_length = max(max_input_length, request_input_length) - max_current_length = max( - max_current_length, request_cache_length + request_input_length - ) + max_current_length = max(max_current_length, request_cache_length + request_input_length) all_input_ids.append(self.all_input_ids[idx]) @@ -445,16 +437,11 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Copy to tensor (CPU) slot_indices[i] = cumulative_max_length - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) + remaining_tokens = stopping_criteria.max_new_tokens - stopping_criteria.current_tokens # Set slice slot_filtering_indices[ - self.slot_indices[idx] : self.slot_indices[idx] - + request_input_length - + remaining_tokens - - 1 + self.slot_indices[idx] : self.slot_indices[idx] + request_input_length + remaining_tokens - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 @@ -489,9 +476,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": slot_indices = slot_indices.to(device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) + 
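The from_pb hunk above splits each prompt into a cached prefix (cache_length tokens already materialized in the KV cache) and the chunk to compute in this pass (input_length tokens). A small sketch of that slicing; the helper name and its default argument are chosen here for illustration and are not part of the patch:

    def next_prefill_chunk(tokenized_input, cache_length, input_length=None):
        if input_length is not None:
            # Chunked prefill: compute only the next `input_length` tokens after the
            # `cache_length` tokens that are already in the KV cache.
            postfix_ids = tokenized_input[cache_length : cache_length + input_length]
            assert len(postfix_ids) == input_length, "Rust and Python tokenizers are not aligned"
            return postfix_ids
        # No chunk bound: use all the remaining ids.
        return tokenized_input[cache_length:]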
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, @@ -559,20 +544,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # `total_slots` is not used if any of the batches is prefilling total_slots += len(b.slots) if not b.prefilling else 0 num_blocks += b.num_blocks - speculative_length = ( - b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 - ) + speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 max_input_length = max(max_input_length, b.max_input_length) max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - prompt_length - + stopping_criteria.max_new_tokens - + speculative_length - for prompt_length, stopping_criteria in zip( - b.prompt_lengths, b.stopping_criterias - ) + prompt_length + stopping_criteria.max_new_tokens + speculative_length + for prompt_length, stopping_criteria in zip(b.prompt_lengths, b.stopping_criterias) ), ) prefilling = prefilling or b.prefilling @@ -592,18 +571,10 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) - input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( - total_batch_size - ) - cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( - total_batch_size - ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) + input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size) + cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(total_batch_size) + total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) adapter_segment_builder = SegmentConcatBuilder() adapter_set = set() @@ -663,21 +634,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids slots[slots_start_index:slots_end_index] = batch.slots - slot_indices[start_index:end_index] = ( - batch.slot_indices + cumulative_slots - ) + slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor # Copy over adapter indices adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) + adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0] + adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices cumulative_adapter_indices_size = adapter_end_index adapter_set.update(batch.adapter_meta.adapter_set) adapter_segment_builder.concat( @@ -773,7 +737,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max_blocks=max_blocks, 
adapter_meta=adapter_meta, ) - + def prepare_for_prefill(self): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -821,20 +785,14 @@ def prepare_for_prefill(self): ): next_chunk_length = input_length # Position ids - request_position_ids = torch.arange( - cache_length, cache_length + input_length, dtype=torch.int32 - ) + request_position_ids = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs cu_seqlen_prefill.append(cumulative_length + input_length) if not r.slots: - request_slots = [ - s - for b in blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] + request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)] else: request_slots = r.slots @@ -867,9 +825,7 @@ def prepare_for_prefill(self): dtype=torch.int64, ) ) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) + prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: @@ -917,18 +873,12 @@ def prepare_for_prefill(self): prefill_cache_indices = prefill_cache_indices[0] self.prefill_cu_outlens = prefill_cu_outlens - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) + cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, device=device, dtype=torch.int32) self.cu_seqlen_prefill = cu_seqlen_prefill self.position_ids = position_ids.to(device) self.slot_indices = slot_indices.to(device) - self.prefill_cache_indices = ( - prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None - ) - self.input_lengths_tensor = torch.tensor( - self.input_lengths, dtype=torch.int32, device=device - ) + self.prefill_cache_indices = prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None + self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32, device=device) if all_prefill_logprobs: prefill_head_indices = None @@ -938,23 +888,15 @@ def prepare_for_prefill(self): prefill_next_token_indices = None else: prefill_head_indices = torch.cat(prefill_head_indices).to(device) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) + prefill_next_token_indices = torch.tensor(prefill_next_token_indices, dtype=torch.int64, device=device) self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices self.slots = torch.tensor(slots, dtype=torch.int64, device=device) - self.cache_lengths_tensor = torch.tensor( - self.cache_lengths, dtype=torch.int32, device=device - ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) + self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32, device=device) + adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device) self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, @@ -1049,7 +991,10 @@ def __init__( weights._set_config(model_id, config) self._supports_embeddings = embedding_dim is not None - if not 
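When a request arrives without explicit slots, prepare_for_prefill above expands its block ids into one KV-cache slot per token. A tiny worked example, with BLOCK_SIZE fixed to 16 purely for illustration (the real value is imported from lorax_server.utils.state):

    BLOCK_SIZE = 16  # assumed for this example only
    blocks = [3, 7]
    request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
    # -> [48, 49, ..., 63, 112, 113, ..., 127]: 32 slots, one per token position of the request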
(weights.has_tensor("lm_head.weight") or weights.has_tensor("language_model.lm_head.weight")) and not self._supports_embeddings: + if ( + not (weights.has_tensor("lm_head.weight") or weights.has_tensor("language_model.lm_head.weight")) + and not self._supports_embeddings + ): raise ValueError( "Model does not have lm head so it is presumed to be for embeddings." "No embedding_dim was provided so we cannot load the model." @@ -1365,10 +1310,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> prefill = batch.cu_seqlen_prefill is not None model = self.model use_graph = False - if ( - self.model_graph_wrapper is not None - and not prefill - ): + if self.model_graph_wrapper is not None and not prefill: if self.model_graph_wrapper.can_use_graph(batch, adapter_data): use_graph = True model = self.model_graph_wrapper @@ -1402,13 +1344,13 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> input_ids = new_input_ids position_ids = new_position_ids - + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - + seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, @@ -1416,7 +1358,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> max_q=batch.max_input_length, max_k=batch.max_current_length, ) - + # Model Forward if not use_graph: # eager mode @@ -1517,7 +1459,7 @@ def generate_token( prefill_logprobs = None next_token_logits = out next_adapter_indices = batch.adapter_meta.adapter_indices - + finished_prefilling = True next_chunk_lengths = [] current_prefilling_mask = batch.prefilling_mask @@ -1536,13 +1478,9 @@ def generate_token( reversed(batch.input_lengths), reversed(batch.prompt_lengths), ): - remaining_prefill_tokens = max( - prompt_length - cache_length - input_length, 0 - ) + remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0) if remaining_prefill_tokens > 0: - next_chunk_length = max( - min(remaining_prefill_tokens, batch_budget), 1 - ) + next_chunk_length = max(min(remaining_prefill_tokens, batch_budget), 1) batch_budget -= next_chunk_length finished_prefilling = False next_prefilling_mask.append(True) @@ -1645,9 +1583,7 @@ def generate_token( # Logprobs generated by the model are for the next token # So we need to translate the id tensor by 1 - ids = batch.all_input_ids_tensor[ - i, cache_length + 1 : cache_length + input_length + 1 - ] + ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1] if len(batch) > 1: prefill_tokens_indices[out_start_index:out_end_index] = ids else: @@ -1657,9 +1593,7 @@ def generate_token( if not request_is_prefilling: # Only save tokens if we are done prefilling for this request for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( - next_input_ids[index + j] - ) + batch.all_input_ids_tensor[i, cache_length + input_length + j] = next_input_ids[index + j] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] @@ -1684,7 +1618,7 @@ def generate_token( prefill_logprobs = torch.gather(prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)) # GPU <-> CPU sync prefill_logprobs = prefill_logprobs.view(-1).tolist() - + # Does a GPU <-> CPU sync internally if prefill and finished_prefilling: # adjust segment lengths to account for 
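The reversed loop above decides how much of each request's remaining prompt the next prefill pass will compute, given a shared token budget. A standalone sketch of that scheduling policy; the helper name and the chunk value reported for finished requests are illustrative rather than taken from the patch:

    def plan_prefill_chunks(prompt_lengths, cache_lengths, input_lengths, batch_budget):
        # Requests are budgeted in reverse batch order, and any request with prompt
        # tokens still outstanding receives at least one token so it keeps progressing.
        next_chunk_lengths = []
        finished_prefilling = True
        for prompt_length, cache_length, input_length in zip(
            reversed(prompt_lengths), reversed(cache_lengths), reversed(input_lengths)
        ):
            remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0)
            if remaining_prefill_tokens > 0:
                next_chunk_length = max(min(remaining_prefill_tokens, batch_budget), 1)
                batch_budget -= next_chunk_length
                finished_prefilling = False
            else:
                next_chunk_length = 0
            next_chunk_lengths.append(next_chunk_length)
        next_chunk_lengths.reverse()
        return next_chunk_lengths, finished_prefilling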
all request lengths being 1 during decoding @@ -1703,7 +1637,7 @@ def generate_token( if return_alternatives: alternative_token_logprobs = alternative_token_logprobs.tolist() alternative_token_ids = alternative_token_ids.tolist() - + # Update values if we need to continue prefilling # This represents the `else` case of the `Update values` if above # but since this require the `next_token_ids` to be on CPU, it is better to do it here @@ -1733,9 +1667,7 @@ def generate_token( if request_prefilling: next_cache_length = cache_length + input_length # Get new prompt IDs to prefill - postfix_ids = all_input_ids[ - next_cache_length : next_cache_length + next_chunk_length - ] + postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length] else: # This request is done prefilling, the new id is the one selected the sampling method postfix_ids = [next_token_id] @@ -1743,7 +1675,7 @@ def generate_token( all_postfix_ids.append(postfix_ids) batch.input_ids = all_postfix_ids - + # Results generations: List[Generation] = [] stopped = not is_warmup @@ -1815,7 +1747,7 @@ def generate_token( # request_alternative_token_texts, # ) # all_alternative_tokens.append(alternative_tokens) - + # Compute logprobs first as, even though we might skip the token, # it can still be required to compute the logprobs # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need @@ -1831,25 +1763,17 @@ def generate_token( # We need to remove it out_end_index -= 1 - request_prefill_logprobs = prefill_logprobs[ - out_start_index:out_end_index - ] + request_prefill_logprobs = prefill_logprobs[out_start_index:out_end_index] # Logprobs generated by the model are for the next token # So we need to translate the id tensor by 1 - prefill_token_ids = all_input_ids[ - cache_length + 1 : cache_length + input_length + 1 - ] + prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1] past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i] if past_prefill_logprob_tokens is None: # add nan for cached prompt tokens/first token - request_prefill_logprobs = [float("nan")] * ( - cache_length + 1 - ) + request_prefill_logprobs - prefill_token_ids = ( - all_input_ids[: cache_length + 1] + prefill_token_ids - ) + request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs + prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, @@ -1865,9 +1789,7 @@ def generate_token( all_alternative_tokens, ) if past_prefill_logprob_tokens is not None: - prefill_logprob_tokens = ( - past_prefill_logprob_tokens + prefill_logprob_tokens - ) + prefill_logprob_tokens = past_prefill_logprob_tokens + prefill_logprob_tokens batch.prefill_logprob_tokens[i] = prefill_logprob_tokens else: @@ -1914,9 +1836,7 @@ def generate_token( stopped = stopped and current_stopped _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] + _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left] # Shard generations # All generations will be appended in the rust sharded client @@ -1925,11 +1845,8 @@ def generate_token( # Decode generated tokens output_text, _, _ = self.decode_token( all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, + 
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, skip_special_tokens=True, ) generated_text = GeneratedText( @@ -1982,7 +1899,7 @@ def generate_token( ) generations.append(generation) - + # advance the FSM for each accepted token (as we may have more than one from speculative decoding) for next_token_id in _next_token_ids: batch.next_token_chooser.next_state(i, next_token_id) diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py index 78431f17e..85facb25a 100644 --- a/server/lorax_server/models/mllama.py +++ b/server/lorax_server/models/mllama.py @@ -2,7 +2,6 @@ from io import BytesIO from typing import Dict, Iterable, List, Optional, Tuple -from lorax_server.utils.attention.common import Seqlen import numpy as np import torch from opentelemetry import trace @@ -13,6 +12,7 @@ from lorax_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch from lorax_server.pb import generate_pb2 +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.state import PREFIX_CACHING from lorax_server.utils.tokenizer import TokenizerManager @@ -227,7 +227,7 @@ def forward( # in a circular buffer mode. # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) - + seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index ec2da1a2e..f5e9dbe99 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -17,7 +17,13 @@ load_and_merge_adapters, ) from lorax_server.utils.sources import HUB -from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, CHUNKED_PREFILL, get_speculative_tokens, set_supports_chunking +from lorax_server.utils.state import ( + BLOCK_SIZE, + CHUNKED_PREFILL, + FLASH_INFER, + get_speculative_tokens, + set_supports_chunking, +) from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.utils.weights import shard_on_dim @@ -80,8 +86,7 @@ def __init__( if supports_chunking: if speculation_tokens != 0: logger.warning( - "Chunked prefill does not support speculation yet. " - "Chunked prefill will be disabled", + "Chunked prefill does not support speculation yet. 
" "Chunked prefill will be disabled", ) supports_chunking = False if not FLASH_INFER: diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 8c3755c3b..7c4a3b543 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -1,7 +1,6 @@ from io import BytesIO from typing import Iterable, List, Optional, Tuple, Type -from lorax_server.utils.attention.common import Seqlen import torch import torch.distributed from loguru import logger @@ -17,6 +16,7 @@ block_tables_to_ragged, ) from lorax_server.pb import generate_pb2 +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.state import PREFIX_CACHING from lorax_server.utils.tokenizer import TokenizerManager @@ -349,7 +349,7 @@ def forward( ): use_graph = True model = self.model_graph_wrapper - + seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 90cecac71..a75906ea6 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -75,7 +75,7 @@ async def FilterBatch(self, request, context): async def Warmup(self, request: generate_pb2.WarmupRequest, context): set_max_prefill_tokens(request.max_prefill_tokens) - + batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, @@ -104,9 +104,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context): if request.HasField("cached_batch"): cached_batch = self.cache.pop(request.cached_batch.id) if cached_batch is None: - raise ValueError( - f"Batch ID {request.cached_batch.id} not found in cache." - ) + raise ValueError(f"Batch ID {request.cached_batch.id} not found in cache.") batch = self.model.batch_type.concatenate([cached_batch, batch]) generations, next_batch = self.model.generate_token(batch) diff --git a/server/lorax_server/utils/attention/common.py b/server/lorax_server/utils/attention/common.py index 61b7cc502..5afc573e2 100644 --- a/server/lorax_server/utils/attention/common.py +++ b/server/lorax_server/utils/attention/common.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -import torch from typing import Optional +import torch + @dataclass class Seqlen: @@ -50,4 +51,3 @@ def __init__( def clamp(self, max): self.input_lengths = torch.clamp(self.input_lengths, max=max) return self - \ No newline at end of file diff --git a/server/lorax_server/utils/attention/utils.py b/server/lorax_server/utils/attention/utils.py index 80ffcbda2..8292be916 100644 --- a/server/lorax_server/utils/attention/utils.py +++ b/server/lorax_server/utils/attention/utils.py @@ -10,9 +10,7 @@ def block_tables_to_ragged( assert len(input_lengths) == len(cache_lengths) total_len = sum(input_lengths) + sum(cache_lengths) - block_tables_ragged = torch.empty( - total_len, dtype=torch.int32, device=block_tables.device - ) + block_tables_ragged = torch.empty(total_len, dtype=torch.int32, device=block_tables.device) offset = 0 for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index c8f8d4c00..2accc60ca 100644 --- a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -53,9 +53,7 @@ def use_prefill_with_paged_kv_state( `attention` function while the context manager is active. 
""" - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) + indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) # Round up to page size and then calculate the cumulative sum to get # the indices into the block table. torch.add(input_lengths, page_size - 1, out=indptr[1:]) @@ -64,13 +62,9 @@ def use_prefill_with_paged_kv_state( # Get the lengths of the last page in a block. if page_size == 1: - last_page_len = torch.ones( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) + last_page_len = torch.ones(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) else: - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) + last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) torch.sub(input_lengths, 1, out=last_page_len) last_page_len.remainder_(page_size) last_page_len += 1 @@ -198,9 +192,7 @@ def use_decode_state( `state` and parameters. This state will be used by all calls to the `paged_attention` function while the context manager is active. """ - indptr = torch.zeros( - input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 - ) + indptr = torch.zeros(input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32) # Round up to page size and then calculate the cumulative sum to get # the indices into the block table. torch.add(input_lengths, page_size - 1, out=indptr[1:]) @@ -208,9 +200,7 @@ def use_decode_state( indptr[1:].cumsum_(-1) # Get the lengths of the last page in a block. - last_page_len = torch.empty( - input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device - ) + last_page_len = torch.empty(input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device) torch.sub(input_lengths, 1, out=last_page_len) last_page_len.remainder_(page_size) last_page_len += 1 diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index dd74f6b7f..387ddf86c 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -7,7 +7,6 @@ from statistics import median from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple -from lorax_server.utils.attention.common import Seqlen import numpy as np import torch from loguru import logger @@ -17,6 +16,7 @@ from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.adapters.lora import BatchLoraWeights, RankSegments from lorax_server.adapters.types import LORA +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.sgmv import BGMV_MAX_RANK from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 60bee0e43..fa7ebfeb9 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -1,8 +1,8 @@ from typing import Optional -from lorax_server.utils.attention.common import Seqlen import torch +from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.import_utils import SYSTEM from lorax_server.utils.state import FLASH_INFER