From c82b432d4a40fd6376a35fd38cb5fc37e9c53798 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 29 Nov 2024 13:17:57 +0800 Subject: [PATCH 1/6] [Misc] typo find in sampling_metadata.py (#10740) --- vllm/model_executor/sampling_metadata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 84f35f75a0c32..1df8f84ed4093 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -454,6 +454,7 @@ def from_sampling_metadata( if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) From 3132aac04326286ae996bf0887e920096b2bb210 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 29 Nov 2024 21:56:46 +0800 Subject: [PATCH 2/6] [Bugfix] Fix Idefics3 bug (#10778) Signed-off-by: Jee Jee Li --- vllm/model_executor/models/idefics3.py | 92 +++++++++++++------------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 014e27bc869d4..e5d2edbd81eb1 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -267,54 +267,56 @@ def input_processor_for_idefics3(ctx: InputContext, n_images_in_text = [] text = inputs.get("prompt") - if text is not None: - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, " - "or a list of strings") - - fake_image_token = processor.fake_image_token.content - image_token = processor.image_token.content - global_img_token = processor.global_image_tag - - prompt_strings = [] - for sample, sample_rows, sample_cols in zip(text, image_rows, - image_cols): - n_images_in_text.append(sample.count(image_token)) - - # Replace the image token with fake tokens around the expanded - # image token sequence of length `image_seq_len` - image_prompt_strings = [] - for n_rows, n_cols in zip(sample_rows, sample_cols): - image_prompt_string = _get_image_prompt_string( - n_rows, - n_cols, - processor.image_seq_len, - image_token=image_token, - fake_token_around_image=fake_image_token, - global_img_token=global_img_token, - ) - image_prompt_strings.append(image_prompt_string) - - split_sample = sample.split(image_token) - if len(split_sample) == 0: - raise ValueError( - "The image token should be present in the text.") + if text is None: + prompt_token_ids = inputs.get("prompt_token_ids", []) + assert prompt_token_ids + text = tokenizer.decode(prompt_token_ids) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, " + "or a list of strings") + + fake_image_token = processor.fake_image_token.content + image_token = processor.image_token.content + global_img_token = processor.global_image_tag + + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): + n_images_in_text.append(sample.count(image_token)) + + # Replace the image token with fake tokens around the expanded + # image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = _get_image_prompt_string( + n_rows, + n_cols, + processor.image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) - # Place in the image prompt strings where the image tokens are - sample = split_sample[0] - for i, image_prompt_string in enumerate(image_prompt_strings): - sample += image_prompt_string + split_sample[i + 1] - prompt_strings.append(sample) + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError("The image token should be present in the text.") - prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt_strings[0], - multi_modal_data=multi_modal_data, - ) + prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt_strings[0], + multi_modal_data=multi_modal_data, + ) def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int: From 661175bc826f4caba04182a1faeeca9e7a3259ac Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Fri, 29 Nov 2024 23:22:21 +0800 Subject: [PATCH 3/6] [platform] Add verify_quantization in platform. (#10757) Signed-off-by: wangxiyuan --- vllm/config.py | 28 +--------------------------- vllm/platforms/cpu.py | 1 + vllm/platforms/cuda.py | 1 + vllm/platforms/hpu.py | 1 + vllm/platforms/interface.py | 13 +++++++++++++ vllm/platforms/neuron.py | 2 ++ vllm/platforms/openvino.py | 1 + vllm/platforms/rocm.py | 15 +++++++++++++++ vllm/platforms/tpu.py | 2 ++ vllm/platforms/xpu.py | 1 + 10 files changed, 38 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index cd24e9ffdf598..b1e5b412fec8f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -393,17 +393,11 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = QUANTIZATION_METHODS - rocm_supported_quantization = [ - "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf" - ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] - tpu_supported_quantization = ["tpu_int8"] - neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -438,32 +432,12 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") - if current_platform.is_rocm( - ) and self.quantization not in rocm_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in ROCm.") - if current_platform.is_tpu( - ) and self.quantization not in tpu_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in TPU Backend.") + current_platform.verify_quantization(self.quantization) if self.quantization not in optimized_quantization_methods: logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.", self.quantization) - if (self.quantization == "awq" and current_platform.is_rocm() - and not envs.VLLM_USE_TRITON_AWQ): - logger.warning( - "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" - " is not set, enabling VLLM_USE_TRITON_AWQ.") - envs.VLLM_USE_TRITON_AWQ = True - if current_platform.is_neuron( - ) and self.quantization not in neuron_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in Neuron Backend.") def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 3e22c87f61fac..b5333fbd6f502 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -19,6 +19,7 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU + device_name: str = "cpu" device_type: str = "cpu" dispatch_key: str = "CPU" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 5e9ce551f2332..846a1869da228 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -72,6 +72,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: class CudaPlatformBase(Platform): _enum = PlatformEnum.CUDA + device_name: str = "cuda" device_type: str = "cuda" dispatch_key: str = "CUDA" diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 3071136e43b85..10aaa6d54962c 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -12,6 +12,7 @@ class HpuPlatform(Platform): _enum = PlatformEnum.HPU + device_name: str = "hpu" device_type: str = "hpu" dispatch_key: str = "HPU" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3328665029039..eac2b413f9271 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -56,11 +56,13 @@ def to_int(self) -> int: class Platform: _enum: PlatformEnum + device_name: str device_type: str # available dispatch keys: # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa # use "CPU" as a fallback for platforms not registered in PyTorch dispatch_key: str = "CPU" + supported_quantization: list[str] = [] def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA @@ -171,6 +173,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: """ pass + @classmethod + def verify_quantization(cls, quant: str) -> None: + """ + Verify whether the quantization is supported by the current platform. + """ + if cls.supported_quantization and \ + quant not in cls.supported_quantization: + raise ValueError( + f"{quant} quantization is currently not supported in " + f"{cls.device_name}.") + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 4c4d778ed3dd4..87655ea198303 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -10,7 +10,9 @@ class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON + device_name: str = "neuron" device_type: str = "neuron" + supported_quantization: list[str] = ["neuron_quant"] @classmethod def get_device_name(cls, device_id: int = 0) -> str: diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index ea5ec7b40b95c..29b61e955d9ab 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -23,6 +23,7 @@ class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO + device_name: str = "openvino" device_type: str = "openvino" dispatch_key: str = "CPU" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d2f44c3e423e3..3c14fbc179f69 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum, _Backend @@ -35,8 +36,13 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM + device_name: str = "rocm" device_type: str = "cuda" dispatch_key: str = "CUDA" + supported_quantization: list[str] = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8", "gguf" + ] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -79,3 +85,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "vllm.spec_decode.spec_decode_worker.create_spec_worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + + @classmethod + def verify_quantization(cls, quant: str) -> None: + super().verify_quantization(quant) + if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 137af57023ea9..b138f7e1c54c5 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -16,8 +16,10 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU + device_name: str = "tpu" device_type: str = "tpu" dispatch_key: str = "XLA" + supported_quantization: list[str] = ["tpu_int8"] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 69388a8e0f27c..9665786f4c499 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -16,6 +16,7 @@ class XPUPlatform(Platform): _enum = PlatformEnum.XPU + device_name: str = "xpu" device_type: str = "xpu" dispatch_key: str = "XPU" From 40bc242579d260e6da7614e1494cbd80a6f985b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Sat, 30 Nov 2024 05:07:13 +0100 Subject: [PATCH 4/6] [Bugfix] Fix OpenVino/Neuron `driver_worker` init (#10779) Signed-off-by: NickLucche Signed-off-by: Cyrus Leung Co-authored-by: Cyrus Leung --- vllm/executor/neuron_executor.py | 6 ++++-- vllm/executor/openvino_executor.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 31e6fdc3ab1bb..a9efc4f9a801c 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -29,11 +29,13 @@ def _init_worker(self): wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = wrapper.init_worker( + wrapper.init_worker( vllm_config=self.vllm_config, local_rank=0, rank=0, - distributed_init_method=distributed_init_method) + distributed_init_method=distributed_init_method, + ) + self.driver_worker = wrapper.worker self.driver_worker.init_device() self.driver_worker.load_model() diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index db0070ce510ee..057a32364e512 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -36,7 +36,7 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = wrapper.init_worker( + wrapper.init_worker( ov_core=ov.Core(), vllm_config=self.vllm_config, local_rank=0, @@ -45,6 +45,7 @@ def _init_worker(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) + self.driver_worker = wrapper.worker self.driver_worker.init_device() self.driver_worker.load_model() From 16ee07f22ade57eb882b3c16ad3a6944635996df Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 30 Nov 2024 12:19:14 +0800 Subject: [PATCH 5/6] [Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/molmo.py | 213 +++++++++++++++------------- 1 file changed, 111 insertions(+), 102 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index acedddd84d7cb..98caa6857e211 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -3,7 +3,7 @@ from array import array from dataclasses import dataclass from functools import lru_cache, partial -from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict +from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict import torch from einops import rearrange @@ -44,7 +44,8 @@ from vllm.transformers_utils.processor import get_processor from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (get_vit_attn_backend, +from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -720,6 +721,42 @@ def forward( # image_features: (batch_size, num_image, num_patch, d_model) return image_features + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + @support_torch_compile class MolmoModel(nn.Module): @@ -804,6 +841,28 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "gate_up_proj" in name: + up_proj, gate_proj = loaded_weight.chunk(2, dim=0) + loaded_weight = torch.cat([gate_proj, up_proj], dim=0) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + cached_get_processor = lru_cache(get_processor) @@ -1200,103 +1259,53 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - - params_mapping = [ - ("model.transformer.ln_f.weight", "model.norm.weight"), - ("attn_out", "self_attn.o_proj"), - ("att_proj", "self_attn.qkv_proj"), - ("q_norm", "self_attn.q_norm"), - ("k_norm", "self_attn.k_norm"), - ("attn_norm", "input_layernorm"), - ("ff_norm", "post_attention_layernorm"), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - - embedding_weight = dict() - projector_weight = dict() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - - if "wte.embedding" in name: - embedding_weight["embedding"] = loaded_weight - continue - - if "wte.new_embedding" in name: - embedding_weight["new_embedding"] = loaded_weight - continue - - if "vision_backbone" in name: - if name.startswith("model"): - name = name[len("model."):] - if 'image_projector' in name: - if 'w1' in name: - projector_weight['gate_proj'] = loaded_weight - elif 'w3' in name: - projector_weight['up_proj'] = loaded_weight - elif 'w2' in name: - projector_weight['down_proj'] = loaded_weight - else: - raise ValueError( - f"Unexpected projector weight: {name}") - continue - else: - if "transformer.blocks" in name: - name = name.replace("transformer.blocks", "layers") - - if "ff_proj" in name: - name = name.replace("ff_proj", "mlp.gate_up_proj") - assert 'weight' in name - up_weight, gate_weight = loaded_weight.chunk(2, dim=0) - loaded_weight = torch.cat([gate_weight, up_weight], dim=0) - - elif "ff_out" in name: - if "layers" in name: - name = name.replace("ff_out", "mlp.down_proj") - else: - # lm head - name = name.replace("model.transformer.ff_out", - "lm_head") - - else: - for (param_name, weight_name) in params_mapping: - if param_name in name: - name = name.replace(param_name, weight_name) - break - - try: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - except KeyError: - raise ValueError(f"Unexpected weight: {name}") from None - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - gate_up_proj_weight = torch.cat( - [projector_weight["gate_proj"], projector_weight["up_proj"]], - dim=0) - name = "vision_backbone.image_projector.gate_up_proj.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, gate_up_proj_weight) - - down_proj_weight = projector_weight["down_proj"] - name = "vision_backbone.image_projector.down_proj.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, down_proj_weight) - - embedding_weight = torch.cat( - [embedding_weight["embedding"], embedding_weight["new_embedding"]], - dim=0) - name = "model.embed_tokens.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, embedding_weight) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + # vision backbone mapping + "image_projector.w1.": "image_projector.gate_proj.", + "image_projector.w3.": "image_projector.up_proj.", + "image_projector.w2.": "image_projector.down_proj.", + # language backbone mapping + "att_proj": "self_attn.qkv_proj", + "attn_out": "self_attn.o_proj", + "q_norm": "self_attn.q_norm", + "k_norm": "self_attn.k_norm", + "ff_proj": "mlp.gate_up_proj", + "ff_out": "mlp.down_proj", + "attn_norm": "input_layernorm", + "ff_norm": "post_attention_layernorm", + }, + orig_to_new_prefix={ + # vision backbone mapping + "model.vision_backbone.": "vision_backbone.", + # language backbone mapping + "model.transformer.blocks.": "model.layers.", + "model.transformer.ln_f.": "model.norm.", + # lm_head is renamed to model.transformer.mlp.down_proj firstly, + # we need to run a second renaming for it + "model.transformer.mlp.down_proj.": "lm_head.", + }, + ) + loader = AutoWeightsLoader(self) + weights = _get_weights_with_merged_embedding(weights) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) + + +def _get_weights_with_merged_embedding( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Iterable[Tuple[str, torch.Tensor]]: + embedding_weights = {} + for name, weight in weights: + if "wte.embedding" in name: + embedding_weights["embedding"] = weight + elif "wte.new_embedding" in name: + embedding_weights["new_embedding"] = weight + else: + yield (name, weight) + # this is compatible with most of quantization, + # because they won't quantize embed_tokens + embedding_weights = torch.cat( + [embedding_weights["embedding"], embedding_weights["new_embedding"]], + dim=0, + ) + yield ("model.embed_tokens.weight", embedding_weights) From e7cfc4ef4cc017e0a0229adff9f4b143b38fb421 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 30 Nov 2024 08:45:50 +0100 Subject: [PATCH 6/6] [Interleaved ATTN] Support for Mistral-8B (#10591) Signed-off-by: youkaichao Co-authored-by: youkaichao --- vllm/model_executor/models/llama.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fe94bb352961b..ff0ab011a9158 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,7 +54,7 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter, + extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -114,6 +114,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads @@ -168,6 +169,18 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) + + if hasattr(config, "interleaved_sliding_window"): + if isinstance(config.interleaved_sliding_window, int): + sliding_window = config.interleaved_sliding_window + elif isinstance(config.interleaved_sliding_window, list): + sw_idx = layer_idx % len(config.interleaved_sliding_window) + sliding_window = config.interleaved_sliding_window[sw_idx] + else: + raise ValueError(f"{type(sliding_window)} is not supported.") + else: + sliding_window = None + self.attn = Attention( self.num_heads, self.head_dim, @@ -175,6 +188,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", )