From 7306d49a90a314ca2ece11f47c2249ee7b7c2d3d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 13 May 2024 21:18:18 -0700 Subject: [PATCH] Use SGMV for prefill BGMV for decode (#464) --- .vscode/settings.json | 4 +- server/lorax_server/adapters/lora.py | 145 +++- server/lorax_server/adapters/medusa.py | 4 +- server/lorax_server/adapters/weights.py | 40 +- server/lorax_server/models/causal_lm.py | 3 +- server/lorax_server/models/flash_causal_lm.py | 3 +- server/lorax_server/models/flash_cohere.py | 4 + server/lorax_server/models/flash_dbrx.py | 4 + server/lorax_server/models/flash_gemma.py | 4 + server/lorax_server/models/flash_gpt2.py | 4 + server/lorax_server/models/flash_llama.py | 4 + server/lorax_server/models/flash_mistral.py | 4 + server/lorax_server/models/flash_mixtral.py | 4 + server/lorax_server/models/flash_phi.py | 4 + server/lorax_server/models/flash_phi3.py | 4 + server/lorax_server/models/flash_qwen.py | 4 + server/lorax_server/models/flash_qwen2.py | 4 + server/lorax_server/models/model.py | 16 +- server/lorax_server/utils/graph.py | 91 ++- server/lorax_server/utils/layers.py | 71 +- server/lorax_server/utils/sgmv.py | 1 + .../punica_kernels/bgmv/bgmv_config.h | 47 +- .../punica_kernels/punica_ops.cc | 655 +++++++++--------- server/tests/utils/test_lora.py | 21 +- 24 files changed, 707 insertions(+), 438 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 2467c7616..7c719a59e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -48,6 +48,8 @@ "numeric": "cpp", "tuple": "cpp", "type_traits": "cpp", - "atomic": "cpp" + "atomic": "cpp", + "__locale": "cpp", + "ios": "cpp" } } \ No newline at end of file diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 03caf55c7..b17214093 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -9,7 +9,14 @@ from lorax_server.adapters.config import AdapterConfig, ModuleMap from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights -from lorax_server.utils.sgmv import MAX_RANK_CUSTOM, get_tmp_tensors, orient_for_rank, pad_rank +from lorax_server.utils.sgmv import ( + BGMV_MAX_RANK, + MAX_RANK_CUSTOM, + get_tmp_tensors, + orient_for_rank, + pad_rank, + use_cutlass_shrink, +) if TYPE_CHECKING: from lorax_server.models.model import Model @@ -87,15 +94,49 @@ def __init__( self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 + self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._is_transposed = False + # [num_layers, hidden_size, r] weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] - self.weights_a = torch.stack(weights_a) + self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] - self.weights_b = torch.stack(weights_b) + self._weights_b = torch.stack(weights_b) self.adapter_config = adapter_config + @property + def weights_a(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_b + + @property + def weights_a_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_b + + def _transpose_weights(self): + if self._use_cutlass_shrink: + # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation + self._weights_a = self._weights_a.transpose(1, 2).contiguous() + self._weights_b = self._weights_b.transpose(1, 2).contiguous() + self._is_transposed = not self._is_transposed + @classmethod def get_batch_type(cls) -> BatchAdapterWeights: return BatchLoraWeights @@ -162,13 +203,19 @@ def load( @dataclass class RankSegments: rank: int - tmp_shrink: torch.Tensor - tmp_expand: torch.Tensor + lora_a_ptr: torch.Tensor lora_b_ptr: torch.Tensor + + # prefill (sgmv) + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor segment_starts: torch.Tensor segment_ends: torch.Tensor + # decode (bgmv) + indices: torch.Tensor + @dataclass class BatchLoraWeights(BatchAdapterWeights): @@ -176,6 +223,7 @@ class BatchLoraWeights(BatchAdapterWeights): lora_b: Dict[int, torch.Tensor] adapter_index_configs: Dict[int, LoraConfig] rank_data: Dict[int, RankSegments] + use_sgmv: bool def has_adapter(self, adapter_index: int) -> bool: return adapter_index in self.adapter_index_configs @@ -188,7 +236,9 @@ def key(cls) -> str: return LORA @classmethod - def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights": + def load( + self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata, prefill: bool + ) -> "BatchLoraWeights": adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)} first_weights = list(adapter_weights.values())[0] @@ -196,28 +246,53 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet segment_indices = meta.segment_indices lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights} - lora_a_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights} - lora_b_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) + + max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights) + + if prefill or max_rank > BGMV_MAX_RANK: + use_sgmv = True + lora_a_ptr = torch.tensor( + [ + (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + else: + use_sgmv = False + lora_a_ptr = torch.tensor( + [ + (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) adapter_index_configs = { idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights } + adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} + rank_indices = defaultdict(list) for segment_idx, adapter_idx in enumerate(segment_indices): if adapter_idx not in adapter_weights: @@ -226,17 +301,32 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet rank_data = {} for rank, indices in rank_indices.items(): - lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) + tmp_shrink = None + tmp_expand = None + segment_starts = None + segment_ends = None + batch_indices = None + + if use_sgmv: + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) + segment_starts = meta.adapter_segments[indices] + segment_ends = meta.adapter_segments[[i + 1 for i in indices]] + else: + rank_indices = set(indices) + batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()] + batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices] + batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device) rank_data[rank] = RankSegments( rank=rank, tmp_shrink=tmp_shrink, tmp_expand=tmp_expand, - lora_a_ptr=lora_a_ptr_indices, + lora_a_ptr=lora_a_ptr[indices], lora_b_ptr=lora_b_ptr[indices], - segment_starts=meta.adapter_segments[indices], - segment_ends=meta.adapter_segments[[i + 1 for i in indices]], + segment_starts=segment_starts, + segment_ends=segment_ends, + indices=batch_indices, ) return BatchLoraWeights( @@ -244,6 +334,7 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet lora_b=lora_b, adapter_index_configs=adapter_index_configs, rank_data=rank_data, + use_sgmv=use_sgmv, ) diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 372a3585c..90361a3ba 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -269,7 +269,9 @@ def __call__(self, x, lm_head): return lm_head(x) @classmethod - def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchMedusaWeights": + def load( + cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool + ) -> "BatchMedusaWeights": adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)} default_medusa = adapter_weights.get(0) diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index 7eee90666..c42431912 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -45,7 +45,9 @@ def key(cls) -> str: pass @abstractclassmethod - def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchAdapterWeights": + def load( + cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool + ) -> "BatchAdapterWeights": pass @@ -70,7 +72,7 @@ def max_speculative_tokens(self) -> int: def is_empty(self) -> bool: return len(self.adapter_weights) == 0 - def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights]: + def get_data(self, meta: AdapterBatchMetadata, prefill: bool) -> Dict[str, BatchAdapterWeights]: # bucket adapters by batch class adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict) for adapter_index, adapter_weights in self.adapter_weights.items(): @@ -78,7 +80,7 @@ def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights] batch_data = {} for batch_type, adapter_weights in adapter_batch_types.items(): - batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta) + batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta, prefill) return batch_data @@ -89,22 +91,40 @@ class AdapterBatchData: # layer type -> adapter type -> batch weight data data: Dict[str, Dict[str, BatchAdapterWeights]] + prefill: bool + @staticmethod - def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]) -> "AdapterBatchData": + def from_meta( + meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights], prefill: bool + ) -> "AdapterBatchData": data = {} for k, v in weights.items(): if v.is_empty(): continue - data[k] = v.get_data(meta) - return AdapterBatchData(meta=meta, data=data) + data[k] = v.get_data(meta, prefill) + return AdapterBatchData(meta=meta, data=data, prefill=prefill) def ranks(self) -> Set[int]: # TODO(travis): refactor to be less coupled to lora implementation - lora_data = self.data.get(LORA) - if lora_data is None: - return set() + ranks = set() + for layer_data in self.data.values(): + lora_data = layer_data.get(LORA) + if lora_data is None: + continue + + for rank_data in lora_data.rank_data.values(): + ranks.add(rank_data.rank) + + return ranks + + def layer_names(self) -> Set[str]: + return set(self.data.keys()) - return set(rank_data.rank for layer_data in self.data.values() for rank_data in lora_data.rank_data.values()) + def adapter_keys(self) -> Set[str]: + adapter_keys = set() + for layer_data in self.data.values(): + adapter_keys.update(layer_data.keys()) + return adapter_keys @property def max_rank(self) -> int: diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 2bd95c928..da1bef8ce 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -588,7 +588,8 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option # Assign pointers to LoRA weights # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights) + # Use prefill=True in all cases to force use of SGMV, as the batch is heterogenous + adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.layer_to_adapter_weights, prefill=True) logits, past = self.forward( batch.input_ids, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 65ba597ce..e63bdd025 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -796,6 +796,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int): self.model, self.device, self.adapter_layers, + self.default_traced_adapter_layers, max_total_tokens, self.sliding_window_blocks, ) @@ -951,7 +952,7 @@ def generate_token( # Assign pointers to adapter weights # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta(adapter_meta, self.batched_lora_weights) + adapter_data = AdapterBatchData.from_meta(adapter_meta, self.layer_to_adapter_weights, prefill) out, speculative_logits = self._try_generate_token(batch, adapter_data) diff --git a/server/lorax_server/models/flash_cohere.py b/server/lorax_server/models/flash_cohere.py index 42a03fbee..a4f0f98c5 100644 --- a/server/lorax_server/models/flash_cohere.py +++ b/server/lorax_server/models/flash_cohere.py @@ -123,6 +123,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_dbrx.py b/server/lorax_server/models/flash_dbrx.py index a552b8917..c62bc599b 100644 --- a/server/lorax_server/models/flash_dbrx.py +++ b/server/lorax_server/models/flash_dbrx.py @@ -137,6 +137,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [ATTN_WQKV] + def get_num_layers_for_type(self, layer_type: str) -> int: return len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_gemma.py b/server/lorax_server/models/flash_gemma.py index a572aab1d..44b5c7d22 100644 --- a/server/lorax_server/models/flash_gemma.py +++ b/server/lorax_server/models/flash_gemma.py @@ -118,6 +118,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index 73595c2e4..83cdb46b9 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -110,6 +110,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [ATTN_C_ATTN] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index e3f55118a..3a29a434c 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -141,6 +141,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index 7cc6525f0..030636b2f 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -127,6 +127,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 151e2613e..7e409c14c 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -131,6 +131,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [ATTN_Q_PROJ, ATTN_V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index 0d39b2c9f..a024b7e9e 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -127,6 +127,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [ATTN_Q_PROJ, ATTN_V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_phi3.py b/server/lorax_server/models/flash_phi3.py index cb4b47c24..4a03f6df1 100644 --- a/server/lorax_server/models/flash_phi3.py +++ b/server/lorax_server/models/flash_phi3.py @@ -129,6 +129,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [QKV_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py index 68274a042..3c8c180f3 100644 --- a/server/lorax_server/models/flash_qwen.py +++ b/server/lorax_server/models/flash_qwen.py @@ -114,6 +114,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [ATTN_C_ATTN] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.transformer.h) diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py index 4244b2112..256f3ee87 100644 --- a/server/lorax_server/models/flash_qwen2.py +++ b/server/lorax_server/models/flash_qwen2.py @@ -143,6 +143,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return ADAPTER_LAYERS + @property + def default_traced_adapter_layers(self) -> List[str]: + return [ATTN_Q_PROJ, ATTN_V_PROJ] + def get_num_layers_for_type(self, layer_type: str) -> int: return 1 if layer_type == LM_HEAD else len(self.model.model.layers) diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 17eb30705..ab654dab2 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -58,7 +58,7 @@ def __init__( # This may be set to False in the subclass constructor self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled - self.batched_lora_weights: Dict[str, LayerAdapterWeights] = defaultdict(LayerAdapterWeights) + self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(LayerAdapterWeights) self.target_to_layer = self.adapter_target_to_layer() self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -153,6 +153,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return [] + @property + def default_traced_adapter_layers(self) -> List[str]: + return [] + def get_num_layers_for_type(self, layer_type: str) -> int: return 0 @@ -162,7 +166,7 @@ def is_row_parallel(self, layer_type: str) -> bool: @property def max_speculative_tokens(self) -> int: return max( - [layer_weights.max_speculative_tokens for layer_weights in self.batched_lora_weights.values()], + [weights.max_speculative_tokens for weights in self.layer_to_adapter_weights.values()], default=0, ) @@ -224,8 +228,8 @@ def load_adapter( if adapter_weights is None: continue - batched_weights = self.batched_lora_weights[layer_name] - batched_weights.add_adapter(adapter_index, adapter_weights) + layer_weights = self.layer_to_adapter_weights[layer_name] + layer_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: logger.warning(f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}") @@ -273,7 +277,7 @@ def offload_adapter( ) for layer_name in self.adapter_layers: - if layer_name in self.batched_lora_weights: - self.batched_lora_weights[layer_name].remove_adapter(adapter_index) + if layer_name in self.layer_to_adapter_weights: + self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) self.loaded_adapters.remove(adapter_index) diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index fd4d68e2a..a6e1f3c97 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -4,10 +4,11 @@ from dataclasses import dataclass from functools import lru_cache from statistics import median -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch +from loguru import logger from torch import nn from tqdm import tqdm @@ -15,15 +16,16 @@ from lorax_server.adapters.lora import BatchLoraWeights, RankSegments from lorax_server.adapters.types import LORA from lorax_server.models.cache_manager import BLOCK_SIZE, get_cache_manager -from lorax_server.utils.sgmv import get_tmp_expand_size, get_tmp_tensors, use_cutlass_shrink +from lorax_server.utils.sgmv import BGMV_MAX_RANK if TYPE_CHECKING: from lorax_server.models.flash_causal_lm import FlashCausalLMBatch + from lorax_server.models.model import Model # TODO(travis): make this configurable by model / user MAX_BATCH_SIZE = 256 -MAX_RANK = 64 +MAX_RANK = BGMV_MAX_RANK SLOT_PAD_VALUE = -1 SEGMENT_PAD_VALUE = -1 @@ -38,6 +40,8 @@ CACHED_MAX_RANKS = [0, 8, 16, 32, 64] _allowed_ranks = set(CACHED_MAX_RANKS) +assert all([r <= BGMV_MAX_RANK for r in _allowed_ranks]), f"Invalid ranks: {_allowed_ranks}" + MAX_SAMPLES = 3 @@ -56,7 +60,7 @@ def get_cached_batch_size(batch_size: int) -> int: def pad_and_fill(dest: torch.Tensor, src: torch.Tensor, pad_value: int): - dest[: src.shape[0]] = src + dest[: src.shape[0]].copy_(src, non_blocking=True) dest[src.shape[0] :].fill_(pad_value) @@ -73,6 +77,7 @@ class GraphState: slots: torch.Tensor input_lengths: torch.Tensor adapter_data: AdapterBatchData + traced_adapter_layer_names: Set[str] @lru_cache(maxsize=1) @@ -95,8 +100,6 @@ def get_max_graph_state( slots = torch.full((MAX_BATCH_SIZE,), SLOT_PAD_VALUE, dtype=torch.int64, device=device) input_lengths = torch.ones((MAX_BATCH_SIZE,), dtype=torch.int32, device=device) - tmp_shrink, tmp_expand = get_tmp_tensors(MAX_BATCH_SIZE, MAX_RANK, device) - adapter_weight_data = {} for layer_name in adapter_layers: adapter_weight_data[layer_name] = BatchLoraWeights( @@ -106,14 +109,16 @@ def get_max_graph_state( rank_data={ MAX_RANK: RankSegments( rank=MAX_RANK, - tmp_shrink=tmp_shrink, - tmp_expand=tmp_expand, lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), - segment_starts=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device), - segment_ends=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int32, device=device), + indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), + segment_starts=None, + segment_ends=None, + tmp_shrink=None, + tmp_expand=None, ), }, + use_sgmv=False, # bgmv during decode ) return GraphState( @@ -130,7 +135,9 @@ def get_max_graph_state( segment_indices=[], ), data=adapter_weight_data, + prefill=False, ), + traced_adapter_layer_names=set(adapter_layers), ) @@ -159,6 +166,7 @@ def trace( memory_pool: Tuple[int, int], max_total_tokens: int, sliding_window_blocks: Optional[int] = None, + traced_adapter_layer_names: Optional[Set[str]] = None, ) -> "GraphWrapper": max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, sliding_window_blocks) @@ -169,14 +177,12 @@ def trace( # But we need to investigate further. segment_size = next_pow_2(batch_size) + traced_adapter_layer_names = traced_adapter_layer_names or set() + adapter_weight_data = {} for layer_name, weight_data in max_input_state.adapter_data.data.items(): - tmp_expand_size = get_tmp_expand_size(segment_size) - - tmp_shrink = weight_data.rank_data[MAX_RANK].tmp_shrink - if use_cutlass_shrink(max_rank): - # cutlass shrink uses a custom temp buffer per rank - tmp_shrink = tmp_shrink[:tmp_expand_size] + if layer_name not in traced_adapter_layer_names: + continue adapter_weight_data[layer_name] = { LORA: BatchLoraWeights( @@ -187,17 +193,19 @@ def trace( { max_rank: RankSegments( rank=max_rank, - tmp_shrink=tmp_shrink, - tmp_expand=weight_data.rank_data[MAX_RANK].tmp_expand[:tmp_expand_size], lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size], lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size], - segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:segment_size], - segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:segment_size], + indices=weight_data.rank_data[MAX_RANK].indices[:batch_size], + segment_starts=None, + segment_ends=None, + tmp_shrink=None, + tmp_expand=None, ), } if max_rank > 0 else {} ), + use_sgmv=False, # bgmv during decode ) } @@ -215,7 +223,9 @@ def trace( segment_indices=max_input_state.adapter_data.meta.segment_indices, ), data=adapter_weight_data, + prefill=False, ), + traced_adapter_layer_names=traced_adapter_layer_names, ) torch.cuda.synchronize(device) @@ -261,28 +271,21 @@ def forward( self.input_state.block_tables[: block_tables.shape[0], : block_tables.shape[1]] = block_tables for layer_name, weight_data in self.input_state.adapter_data.data.items(): + # TODO(travis): generalize this to support other adapter types lora_data = weight_data[LORA] if layer_name not in adapter_data.data: # zero out all the segments for rank_data in lora_data.rank_data.values(): - rank_data.segment_starts.fill_(SEGMENT_PAD_VALUE) - rank_data.segment_ends.fill_(SEGMENT_PAD_VALUE) + rank_data.indices.fill_(SEGMENT_PAD_VALUE) continue - source_data = adapter_data.data[layer_name] + source_data = adapter_data.data[layer_name][LORA] dest_data = lora_data for rank, source_rank_data in source_data.rank_data.items(): dest_rank_data = dest_data.rank_data[rank] - pad_and_fill(dest_rank_data.lora_a_ptr, source_rank_data.lora_a_ptr, 0) pad_and_fill(dest_rank_data.lora_b_ptr, source_rank_data.lora_b_ptr, 0) - - pad_and_fill( - dest_rank_data.segment_starts, - source_rank_data.segment_starts, - SEGMENT_PAD_VALUE, - ) - pad_and_fill(dest_rank_data.segment_ends, source_rank_data.segment_ends, SEGMENT_PAD_VALUE) + pad_and_fill(dest_rank_data.indices, source_rank_data.indices, SEGMENT_PAD_VALUE) self.graph.replay() @@ -295,17 +298,19 @@ def __call__(self, *args, **kwargs): class GraphCache: def __init__( self, - model: nn.Module, + model: "Model", device: torch.device, adapter_layers: List[str], + default_traced_adapter_layers: List[str], max_total_tokens: int, sliding_window_blocks: Optional[int] = None, ): self.model = model self.device = device self.adapter_layers = tuple(adapter_layers) + self.default_traced_adapter_layers = set(default_traced_adapter_layers) self.memory_pool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None - self.cache = {} + self.cache: Dict[Tuple[int, int], GraphWrapper] = {} self.max_total_tokens = max_total_tokens self.sliding_window_blocks = sliding_window_blocks @@ -322,7 +327,7 @@ def can_use_graph( max_s = batch.max_seqlen # Only allow LoRA adapters for now - adapter_keys = set(adapter_data.data.keys()) + adapter_keys = set(adapter_data.adapter_keys()) # TODO(travis): allow using CUDA graphs with multi-rank batches return ( @@ -358,6 +363,7 @@ def get_estimated_cache_memory(self) -> int: pool, self.max_total_tokens, self.sliding_window_blocks, + self.adapter_layers, # estimate memory assuming all adapters are traced ) tmp_cache[key] = graph pool = graph.memory_pool @@ -397,6 +403,7 @@ def warmup(self): pool, self.max_total_tokens, self.sliding_window_blocks, + self.default_traced_adapter_layers, ) self.cache[key] = graph pool = graph.memory_pool @@ -420,17 +427,27 @@ def forward( max_rank = adapter_data.max_rank key = (batch_size, max_rank) - if key not in self.cache: - self.cache[key] = GraphWrapper.trace( + graph = self.cache.get(key) + if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names()): + logger.info( + "Retrace graph with new adapter layers: {} -> {}", + graph.input_state.traced_adapter_layer_names, + adapter_data.layer_names(), + ) + graph = GraphWrapper.trace( self.model, self.device, self.adapter_layers, batch_size, max_rank, self.memory_pool, + self.max_total_tokens, + self.sliding_window_blocks, + adapter_data.layer_names(), ) + self.cache[key] = graph - output_states = self.cache[key].forward( + output_states = graph.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index a1bb3e3c8..5fda02bd9 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -11,6 +11,8 @@ from lorax_server.adapters.types import LORA, MEDUSA from lorax_server.utils.gptq.quant_linear import QuantLinear from lorax_server.utils.sgmv import ( + add_lora_a_bgmv, + add_lora_b_bgmv, has_sgmv, lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, @@ -531,29 +533,54 @@ def forward_layer_type( for r, rank_segments in data.rank_data.items(): lora_a_ptr = rank_segments.lora_a_ptr lora_b_ptr = rank_segments.lora_b_ptr - if lora_a_ptr is not None and lora_b_ptr is not None: - v = lora_a_sgmv_cutlass( - input, - rank_segments.tmp_shrink, - lora_a_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - r, - ) - if self.process_group.size() > 1: - v = self.collect_lora_a(v) - - lora_b_sgmv_cutlass( - proj, - v, - rank_segments.tmp_expand, - lora_b_ptr, - rank_segments.segment_starts, - rank_segments.segment_ends, - self.layer_id, - ) + if data.use_sgmv: + # Use SGMV for prefill + if lora_a_ptr is not None and lora_b_ptr is not None: + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + if lora_a_ptr is not None and lora_b_ptr is not None: + v = torch.zeros((input.size(0), r), dtype=input.dtype, device=input.device) + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) if end_idx - start_idx != result.shape[1]: result[:, start_idx:end_idx] += proj diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 3e565d990..2a4701d09 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -20,6 +20,7 @@ MIN_RANK_CUSTOM = 16 MAX_RANK_CUSTOM = 128 SGMV_BLOCK_SIZE = 16 +BGMV_MAX_RANK = 64 def has_sgmv() -> bool: diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h index 460eb2735..8b53a0cfb 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -1,34 +1,75 @@ #pragma once template -void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, - T** __restrict__ W, - const int64_t* __restrict__ indicies, int64_t y_offset, +void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, + T **__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, int64_t full_y_size, int64_t batch_size, int64_t layer_idx, float scale); // clang-format off #define FOR_BGMV_WIDE(f, T, narrow) \ + f(T, narrow, 128) \ + f(T, narrow, 256) \ + f(T, narrow, 512) \ + f(T, narrow, 640) \ f(T, narrow, 768) \ f(T, narrow, 1024) \ + f(T, narrow, 1152) \ + f(T, narrow, 1280) \ + f(T, narrow, 1536) \ + f(T, narrow, 1728) \ + f(T, narrow, 1792) \ f(T, narrow, 2048) \ + f(T, narrow, 2304) \ f(T, narrow, 2560) \ + f(T, narrow, 2752) \ + f(T, narrow, 2816) \ f(T, narrow, 3072) \ + f(T, narrow, 3456) \ + f(T, narrow, 3584) \ f(T, narrow, 4096) \ + f(T, narrow, 4608) \ f(T, narrow, 5120) \ + f(T, narrow, 5504) \ + f(T, narrow, 5632) \ + f(T, narrow, 6144) \ + f(T, narrow, 6848) \ + f(T, narrow, 6912) \ f(T, narrow, 7168) \ f(T, narrow, 8192) \ f(T, narrow, 9216) \ f(T, narrow, 10240) \ f(T, narrow, 11008) \ f(T, narrow, 12288) \ + f(T, narrow, 13696) \ f(T, narrow, 13824) \ + f(T, narrow, 14336) \ + f(T, narrow, 15360) \ f(T, narrow, 16384) \ f(T, narrow, 20480) \ + f(T, narrow, 22016) \ + f(T, narrow, 24576) \ + f(T, narrow, 27392) \ f(T, narrow, 28672) \ + f(T, narrow, 32000) \ + f(T, narrow, 32256) \ + f(T, narrow, 32512) \ + f(T, narrow, 32768) \ + f(T, narrow, 33024) \ f(T, narrow, 36864) \ + f(T, narrow, 43264) \ f(T, narrow, 49152) \ + f(T, narrow, 64000) \ + f(T, narrow, 64256) \ + f(T, narrow, 64512) \ + f(T, narrow, 102400) \ + f(T, narrow, 102656) \ + f(T, narrow, 102912) \ + f(T, narrow, 128000) \ + f(T, narrow, 128256) \ + f(T, narrow, 128512) \ #define FOR_BGMV_WIDE_NARROW(f, T) \ FOR_BGMV_WIDE(f, T, 8) \ diff --git a/server/punica_kernels/punica_kernels/punica_ops.cc b/server/punica_kernels/punica_kernels/punica_ops.cc index 3722fb9a8..fecef5357 100644 --- a/server/punica_kernels/punica_kernels/punica_ops.cc +++ b/server/punica_kernels/punica_kernels/punica_ops.cc @@ -11,23 +11,27 @@ #include "sgmv/sgmv.h" #include "sgmv_flashinfer/sgmv_config.h" -namespace { - -//====== utils ====== - -inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, - const char* a_name, const char* b_name) { - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", - a.dim(), " vs ", b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, - ".size(", i, ")"); +namespace +{ + + //====== utils ====== + + inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, + const char *a_name, const char *b_name) + { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) + { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } } -} -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { - return (uint32_t(a) << 16) | uint32_t(b); -} + inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) + { + return (uint64_t(a) << 32) | uint64_t(b); + } #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") @@ -49,19 +53,21 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { #define CHECK_GE(a, b) \ TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) -//====== dispatch pytorch dtype ====== + //====== dispatch pytorch dtype ====== #define _DISPATCH_SWITCH(cond, ...) \ [&]() -> bool { \ - switch (cond) { \ + switch (cond) \ + { \ __VA_ARGS__ \ - default: \ - return false; \ + default: \ + return false; \ } \ }() #define _DISPATCH_DTYPE_CASE(enum_type, c_type_, ...) \ - case enum_type: { \ + case enum_type: \ + { \ using c_type = c_type_; \ return __VA_ARGS__(); \ } @@ -73,363 +79,372 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { #define DISPATCH_TORCH_DTYPE(scalar_type, ...) \ _DISPATCH_SWITCH(scalar_type, _DISPATCH_DTYPE_CASES(__VA_ARGS__)) -//====== flashinfer ====== - -void batch_prefill(torch::Tensor o, torch::Tensor q, torch::Tensor qo_indptr, - torch::Tensor kv_ptrs, torch::Tensor kv_indptr, - torch::Tensor last_page_offset, torch::Tensor tmpbuf, - int num_layers, int layer_idx, int num_kv_heads, - int page_size) { - CHECK_INPUT(o); - CHECK_INPUT(q); - CHECK_INPUT(qo_indptr); - CHECK_INPUT(kv_ptrs); - CHECK_INPUT(kv_indptr); - CHECK_INPUT(last_page_offset); - - CHECK_DIM(3, o); // [qo_indptr[-1], N, D] - CHECK_DIM(3, q); // [qo_indptr[-1], N, D] - CHECK_DIM(1, qo_indptr); // [B+1] - CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] - CHECK_DIM(1, kv_indptr); // [B+1] - CHECK_DIM(1, last_page_offset); // [B] - - int batch_size = static_cast(last_page_offset.size(0)); - int num_qo_heads = static_cast(o.size(1)); - int head_dim = static_cast(o.size(2)); - int group_size = num_qo_heads / num_kv_heads; - CHECK_SHAPE(o, q); - CHECK_EQ(num_qo_heads, group_size * num_kv_heads); - CHECK_EQ(qo_indptr.size(0), batch_size + 1); - CHECK_EQ(kv_indptr.size(0), batch_size + 1); - CHECK_GE(tmpbuf.nbytes(), sizeof(int32_t) * (4 * batch_size + 1)); - CHECK_GE(tmpbuf.nbytes(), 64 << 20); - - bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] { - return FlashInferBatchPrefillKernel( - static_cast(o.data_ptr()), static_cast(q.data_ptr()), - qo_indptr.data_ptr(), - reinterpret_cast(kv_ptrs.data_ptr()), - kv_indptr.data_ptr(), last_page_offset.data_ptr(), - tmpbuf.data_ptr(), head_dim, num_layers, layer_idx, group_size, - num_kv_heads, page_size, batch_size); - }); - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(), - " page_size=", page_size, " group_size=", group_size, - " head_dim=", head_dim); -} - -void batch_decode(torch::Tensor o, torch::Tensor q, torch::Tensor kv_ptrs, - torch::Tensor kv_indptr, torch::Tensor last_page_offset, - torch::Tensor tmpbuf, int num_layers, int layer_idx, - int num_kv_heads, int page_size) { - CHECK_INPUT(o); - CHECK_INPUT(q); - CHECK_INPUT(kv_ptrs); - CHECK_INPUT(kv_indptr); - CHECK_INPUT(last_page_offset); - - CHECK_DIM(3, o); // [B, N, D] - CHECK_DIM(3, q); // [B, N, D] - CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] - CHECK_DIM(1, kv_indptr); // [B+1] - CHECK_DIM(1, last_page_offset); // [B] - - int batch_size = static_cast(o.size(0)); - int num_qo_heads = static_cast(o.size(1)); - int head_dim = static_cast(o.size(2)); - int group_size = num_qo_heads / num_kv_heads; - CHECK_SHAPE(o, q); - CHECK_EQ(num_qo_heads, group_size * num_kv_heads); - CHECK_EQ(kv_indptr.size(0), batch_size + 1); - CHECK_EQ(last_page_offset.size(0), batch_size); - CHECK_GE(tmpbuf.nbytes(), sizeof(int32_t) * (4 * batch_size + 1)); - CHECK_GE(tmpbuf.nbytes(), 64 << 20); - - bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] { - return FlashInferBatchDecodeKernel( - static_cast(o.data_ptr()), static_cast(q.data_ptr()), - reinterpret_cast(kv_ptrs.data_ptr()), - kv_indptr.data_ptr(), last_page_offset.data_ptr(), - tmpbuf.data_ptr(), head_dim, num_layers, layer_idx, group_size, - num_kv_heads, page_size, batch_size); - }); - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(), - " page_size=", page_size, " group_size=", group_size, - " head_dim=", head_dim); -} - -void init_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr, - torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v, - torch::Tensor seqlen_indptr, int num_layers, int layer_idx, - int num_kv_heads, int page_size) { - CHECK_INPUT(kv_ptrs); - CHECK_INPUT(kv_indptr); - CHECK_INPUT(last_page_offset); - CHECK_INPUT(k); - CHECK_INPUT(v); - CHECK_INPUT(seqlen_indptr); - - CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] - CHECK_DIM(1, kv_indptr); // [B+1] - CHECK_DIM(1, last_page_offset); // [B] - CHECK_DIM(3, k); // [sum(seqlen_i), N, D] - CHECK_DIM(3, v); // [sum(seqlen_i), N, D] - CHECK_DIM(1, seqlen_indptr); // [B+1] - - int head_dim = static_cast(k.size(2)); - int batch_size = static_cast(last_page_offset.size(0)); - CHECK_EQ(kv_indptr.size(0), batch_size + 1); - CHECK_EQ(seqlen_indptr.size(0), batch_size + 1); - CHECK_SHAPE(k, v); - -#define CASE(dim, _) \ - case dim: \ - FlashInferInitKvKernel( \ - reinterpret_cast(kv_ptrs.data_ptr()), \ - kv_indptr.data_ptr(), last_page_offset.data_ptr(), \ - static_cast(k.data_ptr()), \ - static_cast(v.data_ptr()), seqlen_indptr.data_ptr(), \ - num_layers, layer_idx, num_kv_heads, page_size, batch_size); \ + //====== flashinfer ====== + + void batch_prefill(torch::Tensor o, torch::Tensor q, torch::Tensor qo_indptr, + torch::Tensor kv_ptrs, torch::Tensor kv_indptr, + torch::Tensor last_page_offset, torch::Tensor tmpbuf, + int num_layers, int layer_idx, int num_kv_heads, + int page_size) + { + CHECK_INPUT(o); + CHECK_INPUT(q); + CHECK_INPUT(qo_indptr); + CHECK_INPUT(kv_ptrs); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(last_page_offset); + + CHECK_DIM(3, o); // [qo_indptr[-1], N, D] + CHECK_DIM(3, q); // [qo_indptr[-1], N, D] + CHECK_DIM(1, qo_indptr); // [B+1] + CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, last_page_offset); // [B] + + int batch_size = static_cast(last_page_offset.size(0)); + int num_qo_heads = static_cast(o.size(1)); + int head_dim = static_cast(o.size(2)); + int group_size = num_qo_heads / num_kv_heads; + CHECK_SHAPE(o, q); + CHECK_EQ(num_qo_heads, group_size * num_kv_heads); + CHECK_EQ(qo_indptr.size(0), batch_size + 1); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_GE(tmpbuf.nbytes(), sizeof(int32_t) * (4 * batch_size + 1)); + CHECK_GE(tmpbuf.nbytes(), 64 << 20); + + bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] + { return FlashInferBatchPrefillKernel( + static_cast(o.data_ptr()), static_cast(q.data_ptr()), + qo_indptr.data_ptr(), + reinterpret_cast(kv_ptrs.data_ptr()), + kv_indptr.data_ptr(), last_page_offset.data_ptr(), + tmpbuf.data_ptr(), head_dim, num_layers, layer_idx, group_size, + num_kv_heads, page_size, batch_size); }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(), + " page_size=", page_size, " group_size=", group_size, + " head_dim=", head_dim); + } + + void batch_decode(torch::Tensor o, torch::Tensor q, torch::Tensor kv_ptrs, + torch::Tensor kv_indptr, torch::Tensor last_page_offset, + torch::Tensor tmpbuf, int num_layers, int layer_idx, + int num_kv_heads, int page_size) + { + CHECK_INPUT(o); + CHECK_INPUT(q); + CHECK_INPUT(kv_ptrs); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(last_page_offset); + + CHECK_DIM(3, o); // [B, N, D] + CHECK_DIM(3, q); // [B, N, D] + CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, last_page_offset); // [B] + + int batch_size = static_cast(o.size(0)); + int num_qo_heads = static_cast(o.size(1)); + int head_dim = static_cast(o.size(2)); + int group_size = num_qo_heads / num_kv_heads; + CHECK_SHAPE(o, q); + CHECK_EQ(num_qo_heads, group_size * num_kv_heads); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(last_page_offset.size(0), batch_size); + CHECK_GE(tmpbuf.nbytes(), sizeof(int32_t) * (4 * batch_size + 1)); + CHECK_GE(tmpbuf.nbytes(), 64 << 20); + + bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] + { return FlashInferBatchDecodeKernel( + static_cast(o.data_ptr()), static_cast(q.data_ptr()), + reinterpret_cast(kv_ptrs.data_ptr()), + kv_indptr.data_ptr(), last_page_offset.data_ptr(), + tmpbuf.data_ptr(), head_dim, num_layers, layer_idx, group_size, + num_kv_heads, page_size, batch_size); }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(), + " page_size=", page_size, " group_size=", group_size, + " head_dim=", head_dim); + } + + void init_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr, + torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v, + torch::Tensor seqlen_indptr, int num_layers, int layer_idx, + int num_kv_heads, int page_size) + { + CHECK_INPUT(kv_ptrs); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(last_page_offset); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(seqlen_indptr); + + CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, last_page_offset); // [B] + CHECK_DIM(3, k); // [sum(seqlen_i), N, D] + CHECK_DIM(3, v); // [sum(seqlen_i), N, D] + CHECK_DIM(1, seqlen_indptr); // [B+1] + + int head_dim = static_cast(k.size(2)); + int batch_size = static_cast(last_page_offset.size(0)); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(seqlen_indptr.size(0), batch_size + 1); + CHECK_SHAPE(k, v); + +#define CASE(dim, _) \ + case dim: \ + FlashInferInitKvKernel( \ + reinterpret_cast(kv_ptrs.data_ptr()), \ + kv_indptr.data_ptr(), last_page_offset.data_ptr(), \ + static_cast(k.data_ptr()), \ + static_cast(v.data_ptr()), seqlen_indptr.data_ptr(), \ + num_layers, layer_idx, num_kv_heads, page_size, batch_size); \ return true; - bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] { + bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] + { switch (head_dim) { FOR_FlashInferBatchDecode_D(CASE); default: return false; - } - }); - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", k.scalar_type(), - " head_dim=", head_dim); + } }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", k.scalar_type(), + " head_dim=", head_dim); #undef CASE -} + } -void append_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr, - torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v, - int num_layers, int layer_idx, int num_kv_heads, int page_size) { - CHECK_INPUT(kv_ptrs); - CHECK_INPUT(kv_indptr); - CHECK_INPUT(last_page_offset); - CHECK_INPUT(k); - CHECK_INPUT(v); - - CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] - CHECK_DIM(1, kv_indptr); // [B+1] - CHECK_DIM(1, last_page_offset); // [B] - CHECK_DIM(3, k); // [B, N, D] - CHECK_DIM(3, v); // [B, N, D] - - int head_dim = static_cast(k.size(2)); - int batch_size = static_cast(k.size(0)); - CHECK_EQ(kv_indptr.size(0), batch_size + 1); - CHECK_EQ(last_page_offset.size(0), batch_size); - CHECK_SHAPE(k, v); + void append_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr, + torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v, + int num_layers, int layer_idx, int num_kv_heads, int page_size) + { + CHECK_INPUT(kv_ptrs); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(last_page_offset); + CHECK_INPUT(k); + CHECK_INPUT(v); + + CHECK_DIM(1, kv_ptrs); // [kv_indptr[-1]] ptr to a [L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, last_page_offset); // [B] + CHECK_DIM(3, k); // [B, N, D] + CHECK_DIM(3, v); // [B, N, D] + + int head_dim = static_cast(k.size(2)); + int batch_size = static_cast(k.size(0)); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(last_page_offset.size(0), batch_size); + CHECK_SHAPE(k, v); #define CASE(dim, _) \ case dim: \ FlashInferAppendKvKernel( \ - reinterpret_cast(kv_ptrs.data_ptr()), \ + reinterpret_cast(kv_ptrs.data_ptr()), \ kv_indptr.data_ptr(), last_page_offset.data_ptr(), \ - static_cast(k.data_ptr()), \ - static_cast(v.data_ptr()), num_layers, layer_idx, \ + static_cast(k.data_ptr()), \ + static_cast(v.data_ptr()), num_layers, layer_idx, \ num_kv_heads, page_size, batch_size); \ return true; - bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] { + bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] + { switch (head_dim) { FOR_FlashInferBatchDecode_D(CASE); default: return false; - } - }); - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", k.scalar_type(), - " head_dim=", head_dim); + } }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", k.scalar_type(), + " head_dim=", head_dim); #undef CASE -} - -//====== bgmv ====== - -template -inline bool launch_bgmv_kernel(T* Y, const T* X, T** W, - const int64_t* lora_indices, - uint16_t in_features, uint16_t out_features, - int64_t y_offset, int64_t full_y_size, - int64_t batch_size, - int64_t layer_idx, float scale) { - switch (pack_u16(in_features, out_features)) { -#define CASE_ONESIDE(_T, feat_in, feat_out) \ - case pack_u16(feat_in, feat_out): \ - bgmv_kernel(Y, X, W, lora_indices, y_offset, \ - full_y_size, batch_size, \ - layer_idx, scale); \ + } + + //====== bgmv ====== + + template + inline bool launch_bgmv_kernel(T *Y, const T *X, T **W, + const int64_t *lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, + int64_t layer_idx, float scale) + { + switch (pack_u32(in_features, out_features)) + { +#define CASE_ONESIDE(_T, feat_in, feat_out) \ + case pack_u32(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, \ + layer_idx, scale); \ break; #define CASE(_T, narrow, wide) \ CASE_ONESIDE(T, narrow, wide) \ CASE_ONESIDE(T, wide, narrow) - FOR_BGMV_WIDE_NARROW(CASE, _) + FOR_BGMV_WIDE_NARROW(CASE, _) #undef CASE #undef CASE_ONESIDE default: return false; + } + + return true; } - return true; -} - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, - torch::Tensor indicies, int64_t layer_idx, float scale) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w_ptr); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(1, w_ptr); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t h_in = x.size(1); - int64_t h_out = y.size(1); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - bool ok = false; - if (h_in < 65536 && h_out < 65536) { - switch (x.scalar_type()) { + void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, + torch::Tensor indicies, int64_t layer_idx, float scale) + { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w_ptr); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(1, w_ptr); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) + { + switch (x.scalar_type()) + { case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w_ptr.data_ptr()), + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w_ptr.data_ptr()), indicies.data_ptr(), h_in, h_out, 0, h_out, B, layer_idx, scale); break; case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w_ptr.data_ptr()), + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w_ptr.data_ptr()), indicies.data_ptr(), h_in, h_out, 0, h_out, B, layer_idx, scale); break; default: break; + } } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type()); } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type()); -} - -//====== sgmv ====== - -void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, - torch::Tensor s_start, torch::Tensor s_end, - torch::Tensor tmp, int layer_idx) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w_ptr); - CHECK_INPUT(s_start); - CHECK_INPUT(s_end); - CHECK_INPUT(tmp); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(1, w_ptr); - CHECK_DIM(1, s_start); - CHECK_DIM(1, s_end); - CHECK_DIM(1, tmp); - - int num_problems = s_start.size(0); - int d_in = x.size(1); - int d_out = y.size(1); - CHECK_EQ(tmp.size(0), static_cast(sgmv_tmp_size(num_problems))); - cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); - bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] { - return sgmv((c_type*)y.data_ptr(), (c_type*)x.data_ptr(), (c_type**)w_ptr.data_ptr(), - s_start.data_ptr(), s_end.data_ptr(), - tmp.data_ptr(), num_problems, d_in, d_out, - layer_idx, stream); - }); - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type()); -} - -void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, - torch::Tensor s_start, torch::Tensor s_end, torch::Tensor tmp, int layer_idx) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w_ptr); - CHECK_INPUT(s_start); - CHECK_INPUT(s_end); - CHECK_INPUT(tmp); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(1, w_ptr); - CHECK_DIM(1, s_start); - CHECK_DIM(1, s_end); - CHECK_DIM(1, tmp); - - uint32_t num_problems = s_start.size(0); - uint32_t d_in = x.size(1); - uint32_t d_out = y.size(1); - CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte); - CHECK_EQ(tmp.size(0), 8 * 1024 * 1024); - cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); - -#define CASE(_T, D_OUT) \ - case D_OUT: \ - return sgmv_shrink( \ - (c_type*)y.data_ptr(), (c_type*)x.data_ptr(), \ - (c_type**)w_ptr.data_ptr(), s_start.data_ptr(), s_end.data_ptr(), \ + + //====== sgmv ====== + + void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, + torch::Tensor s_start, torch::Tensor s_end, + torch::Tensor tmp, int layer_idx) + { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w_ptr); + CHECK_INPUT(s_start); + CHECK_INPUT(s_end); + CHECK_INPUT(tmp); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(1, w_ptr); + CHECK_DIM(1, s_start); + CHECK_DIM(1, s_end); + CHECK_DIM(1, tmp); + + int num_problems = s_start.size(0); + int d_in = x.size(1); + int d_out = y.size(1); + CHECK_EQ(tmp.size(0), static_cast(sgmv_tmp_size(num_problems))); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] + { return sgmv((c_type *)y.data_ptr(), (c_type *)x.data_ptr(), (c_type **)w_ptr.data_ptr(), + s_start.data_ptr(), s_end.data_ptr(), + tmp.data_ptr(), num_problems, d_in, d_out, + layer_idx, stream); }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type()); + } + + void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, + torch::Tensor s_start, torch::Tensor s_end, torch::Tensor tmp, int layer_idx) + { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w_ptr); + CHECK_INPUT(s_start); + CHECK_INPUT(s_end); + CHECK_INPUT(tmp); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(1, w_ptr); + CHECK_DIM(1, s_start); + CHECK_DIM(1, s_end); + CHECK_DIM(1, tmp); + + uint32_t num_problems = s_start.size(0); + uint32_t d_in = x.size(1); + uint32_t d_out = y.size(1); + CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte); + CHECK_EQ(tmp.size(0), 8 * 1024 * 1024); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + +#define CASE(_T, D_OUT) \ + case D_OUT: \ + return sgmv_shrink( \ + (c_type *)y.data_ptr(), (c_type *)x.data_ptr(), \ + (c_type **)w_ptr.data_ptr(), s_start.data_ptr(), s_end.data_ptr(), \ tmp.data_ptr(), num_problems, d_in, layer_idx, stream); - bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] { + bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] + { switch (d_out) { FOR_SGMV_NARROW(CASE, c_type); default: return false; - } - }); + } }); #undef CASE - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type(), - " d_out=", d_out); -} - -//====== rms_norm ====== - -void dispatch_rms_norm(torch::Tensor output, torch::Tensor input, - torch::Tensor weight, float epsilon) { - CHECK_INPUT(output); - CHECK_INPUT(input); - CHECK_INPUT(weight); - - CHECK_DIM(2, input); - CHECK_DIM(1, weight); - CHECK_SHAPE(output, input); - CHECK_EQ(input.size(input.dim() - 1), weight.size(0)); - CHECK_EQ(input.scalar_type(), weight.scalar_type()); - CHECK_EQ(input.scalar_type(), output.scalar_type()); - - int rows = input.size(0); - int columns = input.size(1); - - bool ok = DISPATCH_TORCH_DTYPE(input.scalar_type(), [&] { - return rms_norm(static_cast(output.data_ptr()), - static_cast(input.data_ptr()), - static_cast(weight.data_ptr()), rows, - columns, epsilon); - }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type(), + " d_out=", d_out); + } - TORCH_CHECK(ok, "No suitable kernel.", " dtype=", input.scalar_type(), - " columns=", columns); -} + //====== rms_norm ====== + + void dispatch_rms_norm(torch::Tensor output, torch::Tensor input, + torch::Tensor weight, float epsilon) + { + CHECK_INPUT(output); + CHECK_INPUT(input); + CHECK_INPUT(weight); + + CHECK_DIM(2, input); + CHECK_DIM(1, weight); + CHECK_SHAPE(output, input); + CHECK_EQ(input.size(input.dim() - 1), weight.size(0)); + CHECK_EQ(input.scalar_type(), weight.scalar_type()); + CHECK_EQ(input.scalar_type(), output.scalar_type()); + + int rows = input.size(0); + int columns = input.size(1); + + bool ok = DISPATCH_TORCH_DTYPE(input.scalar_type(), [&] + { return rms_norm(static_cast(output.data_ptr()), + static_cast(input.data_ptr()), + static_cast(weight.data_ptr()), rows, + columns, epsilon); }); + + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", input.scalar_type(), + " columns=", columns); + } -} // namespace +} // namespace //====== pybind ====== -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ m.def("batch_prefill", &batch_prefill, ""); m.def("batch_decode", &batch_decode, ""); m.def("init_kv", &init_kv, ""); diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index ea33ed8cd..852d7e359 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -11,17 +11,20 @@ from lorax_server.utils.sgmv import MIN_RANK_CUSTOM -@pytest.mark.parametrize("lora_ranks", [ - [8, 16], - [32, 64], -]) +@pytest.mark.parametrize( + "lora_ranks", + [ + [8, 16], + [32, 64], + ], +) def test_batched_lora_weights(lora_ranks: List[int]): # batch meta is hardcoded with this assumption below assert len(lora_ranks) == 2 batched_weights = LayerAdapterWeights() assert batched_weights.is_empty() - + h = 1024 for idx, lora_rank in enumerate(lora_ranks): weights = LoraWeights( @@ -31,9 +34,9 @@ def test_batched_lora_weights(lora_ranks: List[int]): ) assert weights.lora_a_r == lora_rank assert weights.lora_b_r == lora_rank - + batched_weights.add_adapter(idx, weights) - + assert not batched_weights.is_empty() assert len(batched_weights.adapter_weights) == 2 @@ -45,13 +48,13 @@ def test_batched_lora_weights(lora_ranks: List[int]): ) with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): - data = batched_weights.get_data(meta).get(LORA) + data = batched_weights.get_data(meta, prefill=True).get(LORA) assert len(data.lora_a) == 2 assert data.lora_a.keys() == meta.adapter_set assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h)) assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h)) - + assert len(data.lora_b) == 2 assert data.lora_b.keys() == meta.adapter_set assert data.lora_b[0].shape == (1, lora_ranks[0], h)