diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs
index 08b274f3a..a847936f0 100644
--- a/router/src/scheduler.rs
+++ b/router/src/scheduler.rs
@@ -198,6 +198,9 @@ struct AdapterSchedulerState {
     /// Speculation amount
     speculate: u32,
 
+    /// Prefix caching
+    prefix_caching: bool,
+
     /// Paged Attention Block Allocation
     block_allocator: Option<BlockAllocator>,
 }
@@ -239,6 +242,7 @@ impl AdapterSchedulerState {
             block_size,
             window_size,
             speculate,
+            prefix_caching,
             block_allocator,
         }
     }
@@ -364,17 +368,21 @@ impl AdapterSchedulerState {
             };
             decode_tokens += max_new_tokens;
 
-            if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-            {
-                // Entry is over budget
-                // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                self.queues_state
-                    .lock()
-                    .await
-                    .push_front(&adapter, id, entry);
-                break;
+            // If we're prefix caching, this check could be under-estimating the number of available blocks
+            // due to shared prefixes, so we'll let the block allocator determine whether we have enough space.
+            if !self.prefix_caching {
+                if prefill_tokens > prefill_token_budget
+                    || (prefill_tokens + decode_tokens + self.speculate) > token_budget
+                {
+                    // Entry is over budget
+                    // Add it back to the front
+                    tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                    self.queues_state
+                        .lock()
+                        .await
+                        .push_front(&adapter, id, entry);
+                    break;
+                }
             }
 
             let tokens = entry.request.input_length()
diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
index f66404e6d..abbda50da 100644
--- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -128,7 +128,11 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq"]:
+    input_scale, weight_scale = None, None
+    if isinstance(weight, tuple):
+        weight, input_scale, weight_scale = weight
+
+    if config.quantize not in ["gptq", "awq", "fp8"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads
@@ -142,7 +146,15 @@ def _load_gqa(config, prefix: str, weights):
     w = [weights.get_sharded(f"{p}.bias", dim=0) for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]]
     bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
+    return TensorParallelColumnLinear(
+        get_linear(
+            weight,
+            bias=bias,
+            quantize=config.quantize,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+        )
+    )
 
 
 class FlashQwen2Attention(torch.nn.Module):
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 73fcfea1b..99821bed7 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1158,10 +1158,12 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
         if (
             self.model_graph_wrapper is not None
             and not prefill
-            and self.model_graph_wrapper.can_use_graph(batch, adapter_data)
         ):
-            use_graph = True
-            model = self.model_graph_wrapper
+            if self.model_graph_wrapper.can_use_graph(batch, adapter_data):
+                use_graph = True
+                model = self.model_graph_wrapper
+            else:
+                logger.info("CUDA graphs enabled but batch is incompatible, falling back to eager mode.")
 
         input_ids = batch.input_ids
         position_ids = batch.position_ids
@@ -1194,7 +1196,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
         if not use_graph:
             # eager mode
             input_lengths = input_lengths + prefix_lens_tensor
-            if PREFIX_CACHING:
+            if FLASH_INFER:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py
index 169ca543c..3a840f9ce 100644
--- a/server/lorax_server/utils/graph.py
+++ b/server/lorax_server/utils/graph.py
@@ -43,7 +43,7 @@
 
 # Include 0 to ensure we can use cuda graphs without adapters
 # TODO(travis): use padding to allow for more ranks without increasing memory usage
-CACHED_MAX_RANKS = [0, 8, 16, 32, 64, 96, 128]
+CACHED_MAX_RANKS = [0, 8, 16, 32, 64, 128]
 CACHED_MAX_RANKS = [r for r in CACHED_MAX_RANKS if r <= MAX_RANK]
 _allowed_ranks = set(CACHED_MAX_RANKS)
 
@@ -108,7 +108,7 @@ def get_max_graph_state(
     input_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device)
     position_ids = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)
     slots = torch.full((MAX_BATCH_SIZE,), SLOT_PAD_VALUE, dtype=torch.int64, device=device)
-    input_lengths = torch.ones((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)
+    input_lengths = torch.full((MAX_BATCH_SIZE,), max_total_tokens, dtype=torch.int32, device=device)
     prefix_lens = [0] * MAX_BATCH_SIZE
     prefix_lens_tensor = torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device)
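
The `_load_gqa` change above depends on the weight loader returning either a plain tensor or, for fp8 checkpoints, a (weight, input_scale, weight_scale) tuple. Below is a minimal standalone sketch of just that unpack-or-default pattern; the `unpack_weight` helper and `LoadedWeight` alias are hypothetical names for this example and are not part of the lorax API.

from typing import Optional, Tuple, Union

import torch

# Hypothetical alias for this sketch: a loader may hand back a bare tensor, or a
# (weight, input_scale, weight_scale) tuple for fp8-quantized checkpoints.
LoadedWeight = Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]


def unpack_weight(loaded: LoadedWeight) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return (weight, input_scale, weight_scale), defaulting both scales to None."""
    input_scale, weight_scale = None, None
    if isinstance(loaded, tuple):
        loaded, input_scale, weight_scale = loaded
    return loaded, input_scale, weight_scale


# A plain bf16 weight carries no scales; an fp8-style weight carries both,
# which downstream code can then forward to get_linear as keyword arguments.
w, in_scale, w_scale = unpack_weight(torch.zeros(4, 4, dtype=torch.bfloat16))
assert in_scale is None and w_scale is None

w, in_scale, w_scale = unpack_weight((torch.zeros(4, 4), torch.tensor(1.0), torch.tensor(0.5)))
assert in_scale is not None and w_scale is not None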