From b98884fe00eedab5e834861fb45201c86371e835 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 11 Dec 2023 11:09:17 -0800
Subject: [PATCH] Fix Qwen tensor parallelism (#120)

---
 .../custom_modeling/flash_qwen_modeling.py    | 15 +++--
 server/lorax_server/models/flash_causal_lm.py | 38 +++++++++---
 server/lorax_server/models/flash_qwen.py      | 41 +++++++++++++
 server/lorax_server/utils/layers.py           |  6 +-
 server/lorax_server/utils/lora.py             |  7 +--
 server/lorax_server/utils/weights.py          | 61 ++++++++++++++-----
 6 files changed, 131 insertions(+), 37 deletions(-)

diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
index fe575ab63..1e4a69cca 100644
--- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
@@ -142,8 +142,8 @@ def forward(self, hidden_states, residual=None):
 
 
 def load_attention(config, prefix, weights, layer_id):
-    base_layer = load_attention_multi(config, prefix, weights)
-    projection_size = config.kv_channels * config.num_attention_heads
+    projection_size = (config.hidden_size // config.num_attention_heads) * config.num_attention_heads
+    base_layer = load_attention_multi(config, prefix, weights, projection_size)
     return TensorParallelMultiAdapterLinear.load(
         base_layer, layer_id, [ATTN_C_ATTN], sizes=[
             3 * projection_size,
@@ -151,10 +151,14 @@ def load_attention(config, prefix, weights, layer_id):
     )
 
 
-def load_attention_multi(config, prefix, weights):
+def load_attention_multi(config, prefix, weights, projection_size):
     return TensorParallelColumnLinear.load_multi(
         config,
-        prefixes=[f"{prefix}.c_attn"],
+        prefixes=[
+            (f"{prefix}.c_attn", (0, projection_size)),
+            (f"{prefix}.c_attn", (projection_size, projection_size)),
+            (f"{prefix}.c_attn", (2 * projection_size, projection_size)),
+        ],
         dim=0,
         weights=weights,
         bias=True,
@@ -173,7 +177,8 @@ def __init__(
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
-        self.projection_size = config.kv_channels * config.num_attention_heads
+        self.projection_size = (self.head_size * config.num_attention_heads) // weights.process_group.size()
+        self.process_group = weights.process_group
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 65b14a21f..306c2cabf 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -32,6 +32,7 @@
 from lorax_server.utils.dist import MEMORY_FRACTION
 from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
 from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
+from lorax_server.utils.weights import shard_on_dim
 
 tracer = trace.get_tracer(__name__)
 
@@ -755,6 +756,27 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):
 
         self.adapter_id = adapter_id
 
+    def shard_lora_weights(
+        self,
+        weights_a: List[torch.Tensor],
+        weights_b: List[torch.Tensor],
+        layer_type: str,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        # [hidden_size, r]
+        split_dim = 0 if self.is_row_parallel(layer_type) else 1
+        weights_a = [
+            shard_on_dim(w, dim=split_dim, process_group=self.process_group)
+            for w in weights_a
+        ]
+
+        # [r, hidden_size]
+        weights_b = [
+            shard_on_dim(w, dim=1, process_group=self.process_group)
+            for w in weights_b
+        ]
+
+        return weights_a, weights_b
+
     def load_batched_adapter_weights(
         self,
         module_map: Dict[str, Dict],
@@ -795,7 +817,7 @@ def load_batched_adapter_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
 
         q_lora_merged = MergedLoraWeights(
-            lora_a_list, lora_b_list, adapter_config, layer_type, self.process_group, self.is_row_parallel(layer_type),
+            *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config,
        )
         q_lora_weights = self.batched_lora_weights[layer_type]
         q_lora_weights.add_adapter(adapter_index, q_lora_merged)
@@ -828,7 +850,6 @@ def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
     def warmup(self, batch: FlashCausalLMBatch):
-
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
@@ -841,11 +862,14 @@ def warmup(self, batch: FlashCausalLMBatch):
                 self.device,
             )
             _, batch = self.generate_token(batch)
-        except Exception as e:
-            raise RuntimeError(
-                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
-                f"You need to decrease `--max-batch-prefill-tokens`"
-            ) from e
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
+                raise RuntimeError(
+                    f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                    f"You need to decrease `--max-batch-prefill-tokens`"
+                ) from e
+            else:
+                raise
 
         torch.cuda.synchronize(self.device)
 
diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py
index a75a575ef..f87e10f31 100644
--- a/server/lorax_server/models/flash_qwen.py
+++ b/server/lorax_server/models/flash_qwen.py
@@ -24,6 +24,7 @@
 )
 from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
 from lorax_server.utils.lora import LM_HEAD
+from lorax_server.utils.weights import shard_on_dim
 
 tracer = trace.get_tracer(__name__)
 
@@ -95,6 +96,7 @@ def __init__(
 
         self.model_id = model_id
         model = FlashQwenForCausalLM(config, weights)
+        self.config = config
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashQwen, self).__init__(
@@ -137,3 +139,42 @@ def get_num_layers_for_type(self, layer_type: str) -> int:
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
+
+    def split_lora_b_qkv(self, t: torch.Tensor, projection_size: int) -> torch.Tensor:
+        # Because we're splitting on the hidden size dimension, we need to
+        # account for the separate q, k, and v matrices.
+        chunks = torch.split(t, projection_size, dim=1)
+        assert len(chunks) == 3
+        chunks = [
+            shard_on_dim(w, dim=1, process_group=self.process_group)
+            for w in chunks
+        ]
+        return torch.cat(chunks, dim=1)
+
+    def shard_lora_weights(
+        self,
+        weights_a: List[torch.Tensor],
+        weights_b: List[torch.Tensor],
+        layer_type: str,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        # TODO(travis): generalize this for other layers and architectures
+        if layer_type == ATTN_C_ATTN:
+            # [hidden_size, r]
+            split_dim = 0 if self.is_row_parallel(layer_type) else 1
+            weights_a = [
+                shard_on_dim(w, dim=split_dim, process_group=self.process_group)
+                for w in weights_a
+            ]
+
+            # [r, hidden_size]
+            # Because we're splitting on the hidden size dimension, we need to
+            # account for the separate q, k, and v matrices.
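+            # (Illustrative numbers only, e.g. hidden_size=4096 and 2 shards: lora_b
+            # for c_attn has shape [r, 3 * 4096]; split it into q, k, v chunks of
+            # [r, 4096], shard each to [r, 2048], then re-concatenate to [r, 3 * 2048]
+            # so it lines up with this rank's shard of the fused c_attn weight.)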
+            projection_size = (self.config.hidden_size // self.config.num_attention_heads) * self.config.num_attention_heads
+            weights_b = [
+                self.split_lora_b_qkv(w, projection_size)
+                for w in weights_b
+            ]
+
+            return weights_a, weights_b
+        else:
+            return super().shard_lora_weights(weights_a, weights_b, layer_type)
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index dd23f93c3..33861d808 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -5,7 +5,7 @@
 
 from torch import nn
 from torch.nn import functional as F
-from typing import List
+from typing import List, Tuple, Union
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -368,7 +368,7 @@ def load(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = F
     def load_multi(
         cls,
         config,
-        prefixes: List[str],
+        prefixes: List[Union[str, Tuple]],
         weights,
         bias: bool,
         dim: int,
@@ -379,7 +379,7 @@ def load_multi(
         )
 
         if bias:
-            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
+            b = weights.get_sharded_list("bias", prefixes, dim=0)
             bias = torch.cat(b, dim=dim)
         else:
             bias = None
diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py
index 28af816ef..05b2e1bc3 100644
--- a/server/lorax_server/utils/lora.py
+++ b/server/lorax_server/utils/lora.py
@@ -82,20 +82,15 @@ def __init__(
         weights_a: List[torch.Tensor],
         weights_b: List[torch.Tensor],
         adapter_config: LoraConfig,
-        layer_type: str,
-        process_group: ProcessGroup,
-        is_row_parallel: bool,
     ):
         # [num_layers, hidden_size, r]
-        split_dim = 0 if is_row_parallel else 1
         weights_a = [
-            orient_for_rank(shard_on_dim(w, dim=split_dim, process_group=process_group), adapter_config.r)
+            orient_for_rank(w, adapter_config.r)
             for w in weights_a
         ]
         self.weights_a = torch.stack(weights_a)
 
         # [num_layers, r, hidden_size]
-        weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
         self.weights_b = torch.stack(weights_b)
 
         self.adapter_config = adapter_config
diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index 01d50c9d6..b786751b8 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
@@ -114,15 +114,31 @@ def get_tensor(self, tensor_name: str):
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_partial_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None):
+        """Loads tensor with the given name and shards it along the given dimension.
+
+        The optional range argument can be used to load and split on only a subset of the tensor.
+        This is useful in cases where the tensor is stored as one contiguous block, but is logically
+        split into different components that need to be sharded separately. For example, when storing
+        QKV weights together as a single tensor on disk.
+
+        Args:
+            tensor_name (str): Name of the tensor to load.
+            dim (int): Dimension to shard along.
+            range (Optional[Tuple[int, int]]): Range of indices to load and shard as (offset, size).
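+
+        Example (illustrative shapes only): a fused `c_attn.weight` of shape
+        [3 * 4096, 4096] stores Q, K and V stacked along dim 0. Passing dim=0 and
+        range=(4096, 4096) restricts loading to the K block, so with 2 shards
+        rank 0 loads rows [4096, 6144) and rank 1 loads rows [6144, 8192).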
+ """ filename, tensor_name = self.get_filename(tensor_name) world_size = self.process_group.size() rank = self.process_group.rank() f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) - size = slice_.get_shape()[dim] - start, stop = get_start_stop_idxs_for_rank(size, rank, world_size) + if range is not None: + offset, size = range + else: + offset = 0 + size = slice_.get_shape()[dim] + start, stop = get_start_stop_idxs_for_rank(offset, size, rank, world_size) if dim == 0: tensor = slice_[start:stop] @@ -137,22 +153,33 @@ def get_partial_sharded(self, tensor_name: str, dim: int): tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) world_size = self.process_group.size() - size = slice_.get_shape()[dim] + size = slice_.get_shape()[dim] if range is None else range[1] assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded(tensor_name, dim, range=range) + + def get_sharded_prefix(self, module_name: str, prefix: Union[str, Tuple], dim: int): + if isinstance(prefix, str): + return self.get_sharded(f"{prefix}.{module_name}", dim=dim) + else: + assert isinstance(prefix, tuple) + assert len(prefix) == 2 + return self.get_sharded(f"{prefix[0]}.{module_name}", dim=dim, range=prefix[1]) + + def get_sharded_list(self, module_name: str, prefixes: List[Union[str, Tuple]], dim: int): + return [self.get_sharded_prefix(module_name, p, dim=dim) for p in prefixes] - def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): + def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str, dim: int): if quantize in ["gptq", "awq"]: try: qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + self.get_sharded_list("qweight", prefixes, dim=1), dim=1 ) except RuntimeError: raise RuntimeError( @@ -160,12 +187,14 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): ) qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + self.get_sharded_list("qzeros", prefixes, dim=1), dim=1 ) scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + self.get_sharded_list("scales", prefixes, dim=1), dim=1 ) if quantize == "gptq": + # no tensor parallelism, so remove the range if provided + prefixes = [p[0] if isinstance(p, tuple) else p for p in prefixes] w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -176,7 +205,7 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): bits, groupsize = self._get_gptq_params() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: - w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + w = self.get_sharded_list("weight", prefixes, dim=0) weight = torch.cat(w, dim=dim) return weight @@ -314,10 +343,10 @@ def _set_gptq_params(self, model_id): except Exception: pass -def get_start_stop_idxs_for_rank(size, rank, world_size): +def get_start_stop_idxs_for_rank(offset, size, rank, world_size): block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size + start = offset 
+    stop = offset + (rank + 1) * block_size
     return start, stop
 
 
@@ -326,7 +355,7 @@ def shard_on_dim(t: torch.Tensor, dim: int, process_group: torch.distributed.Pro
     rank = process_group.rank()
 
     size = t.shape[dim]
-    start, stop = get_start_stop_idxs_for_rank(size, rank, world_size)
+    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
 
     if dim == 0:
         tensor = t[start:stop]
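
For readers who want to see the sharding arithmetic in isolation, the short standalone sketch below is not part of the patch: it borrows only the get_start_stop_idxs_for_rank arithmetic shown above, uses plain ints in place of a torch.distributed ProcessGroup, and picks toy sizes (hidden_size=64, 2 ranks). It contrasts naive sharding of a fused QKV ("c_attn"-style) tensor with the per-projection sharding that the ranged prefixes implement.

# Minimal sketch, not LoRAX server code: plain ints replace the
# torch.distributed ProcessGroup used by the real shard_on_dim helper.
import torch


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    # Same arithmetic as the patched helper in weights.py: shard the
    # sub-range [offset, offset + size) evenly across ranks.
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop


hidden_size = 64  # toy stand-in for config.hidden_size
projection_size = hidden_size  # (hidden_size // num_heads) * num_heads
world_size = 2

# Fused c_attn-style weight: Q, K, V stacked along dim 0 -> [3 * hidden, hidden].
qkv = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32)
qkv = qkv.reshape(3 * hidden_size, hidden_size)

for rank in range(world_size):
    # Naive sharding of the fused tensor: rank 0 ends up with all of Q plus
    # half of K, rank 1 with the rest of K plus all of V.
    start, stop = get_start_stop_idxs_for_rank(0, 3 * hidden_size, rank, world_size)
    naive = qkv[start:stop]

    # Per-projection sharding (what the ranged prefixes do): slice each of
    # Q, K, V at its own offset, shard that slice, then concatenate.
    shards = []
    for i in range(3):
        start, stop = get_start_stop_idxs_for_rank(
            i * projection_size, projection_size, rank, world_size
        )
        shards.append(qkv[start:stop])
    per_projection = torch.cat(shards, dim=0)

    # Same shape per rank, but only the second variant keeps a slice of every
    # projection on every rank.
    print(rank, naive.shape, per_projection.shape)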