From f19da64871065510691cd4fcaa5f4096b661dcec Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 7 Oct 2024 18:01:46 +0800 Subject: [PATCH] [Core] Refactor GGUF parameters packing and forwarding (#8859) --- .../models/decoder_only/language/test_gguf.py | 12 +-- vllm/model_executor/layers/linear.py | 76 ++++++++----------- .../layers/quantization/gguf.py | 36 ++++++--- vllm/model_executor/models/llama.py | 2 +- 4 files changed, 64 insertions(+), 62 deletions(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 8fc64a10c84af..5dc83942632fd 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -19,12 +19,12 @@ # FIXME: Move this to confest MODELS = [ - ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", - hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")), - ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", - hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", - filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), + ("meta-llama/Llama-3.2-1B-Instruct", + hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF", + filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")), + ("meta-llama/Llama-3.2-1B-Instruct", + hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF", + filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")), ("Qwen/Qwen2-1.5B-Instruct", hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF", filename="qwen2-1_5b-instruct-q4_k_m.gguf")), diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 568892778abe2..c162ab81c5530 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -440,17 +440,23 @@ def weight_loader(self, param.shard_weight_type[loaded_shard_id] = loaded_weight.item() return - if is_gguf_weight and isinstance(param, UninitializedParameter): - from gguf.constants import GGML_QUANT_SIZES + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) - ori_shape = param.tensor_shape - weight_types = self.qweight_type.shard_weight_type.values() - row_size = [] - for weight_type in weight_types: - block_size, type_size = GGML_QUANT_SIZES[weight_type] - row_size.append(ori_shape[1] // block_size * type_size) - q_shape = (ori_shape[0], max(row_size)) - param.materialize(q_shape, dtype=loaded_weight.dtype) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 2: + self.qweight = param.materialize_nested() + return param_data = param.data output_dim = getattr(param, "output_dim", None) @@ -515,18 +521,6 @@ def weight_loader(self, shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id - if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - output_dim = getattr(param, "output_dim", None) - shard_shape = list(loaded_weight.shape) - shard_shape[output_dim] = shard_shape[output_dim] // tp_size - param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = shard_shape - - input_dim = getattr(param, "input_dim", None) - input_size = loaded_weight.shape[input_dim] - param_data = param_data.narrow(input_dim, 0, input_size) - param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size @@ -783,17 +777,23 @@ def weight_loader(self, param.shard_weight_type[loaded_shard_id] = loaded_weight.item() return - if is_gguf_weight and isinstance(param, UninitializedParameter): - from gguf.constants import GGML_QUANT_SIZES + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() - ori_shape = param.tensor_shape - weight_types = self.qweight_type.shard_weight_type.values() - row_size = [] - for weight_type in weight_types: - block_size, type_size = GGML_QUANT_SIZES[weight_type] - row_size.append(ori_shape[1] // block_size * type_size) - q_shape = (ori_shape[0], max(row_size)) - param.materialize(q_shape, dtype=loaded_weight.dtype) + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 3: + self.qweight = param.materialize_nested() + return param_data = param.data output_dim = getattr(param, "output_dim", None) @@ -883,18 +883,6 @@ def weight_loader(self, shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) - if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - output_dim = getattr(param, "output_dim", None) - shard_shape = list(loaded_weight.shape) - shard_shape[output_dim] = shard_shape[output_dim] // tp_size - param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = shard_shape - - input_dim = getattr(param, "input_dim", None) - input_size = loaded_weight.shape[input_dim] - param_data = param_data.narrow(input_dim, 0, input_size) - param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index dc83017bcc7f9..d73b9f6d92832 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -86,15 +86,16 @@ def create_weights(self, layer: torch.nn.Module, output_size_per_partition = sum(output_partition_sizes) tensor_shape = (output_size_per_partition, input_size_per_partition) - qweight = UninitializedParameter(requires_grad=False) + qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( qweight, { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, "is_gguf_weight": True, - "shard_size": {}, + "data_container": [], "shard_id": [], + "shard_id_map": {}, }) set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("qweight", qweight) @@ -116,21 +117,17 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - shard_size = getattr(layer.qweight, "shard_size", None) shard_id = getattr(layer.qweight, "shard_id", None) - if shard_id and shard_size: - result = [] - offset = 0 + if shard_id: # dequantize shard weights respectively shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight.unbind(0) + result = [] for id in shard_id: - shard_weight = layer.qweight[ - offset:offset + - shard_size[id][0], :shard_size[id][1]].contiguous() + q_idx = layer.qweight.shard_id_map[id] qweight_type = layer.qweight_type.shard_weight_type[id] - result.append(_fuse_mul_mat(x, shard_weight, qweight_type)) - offset += shard_size[id][0] + result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) out = torch.cat(result, axis=1) else: qweight = layer.qweight @@ -162,3 +159,20 @@ def embedding(self, layer: torch.nn.Module, dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, x_flat.shape[0]) return dequant.view(*x.shape, hidden_size) + + +class GGUFUninitializedParameter(UninitializedParameter): + cls_to_become = Parameter + data_container: List[torch.Tensor] + + def materialize_nested(self) -> Parameter: + nested_data = torch.nested.nested_tensor(self.data_container, + device=self.device, + dtype=torch.uint8) + self.data_container.clear() + param = torch.Tensor._make_subclass(self.cls_to_become, + nested_data, + require_grad=False) + for k, v in self.__dict__.items(): + setattr(param, k, v) + return param diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d591d20f7f2f2..8eacf73dd6322 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -512,7 +512,7 @@ def __init__( quant_config=quant_config, ) if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight + self.lm_head = self.model.embed_tokens logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,