Skip to content

Commit

Permalink
[Model] Composite weight loading for multimodal Qwen2 (vllm-project#1…
Browse files Browse the repository at this point in the history
…0944)

Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored and weilong.yu committed Dec 13, 2024
1 parent d39c88c commit 0b1345f
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 205 deletions.
10 changes: 9 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2472,7 +2472,15 @@ def _get_quantization_config(
return quant_config
return None

def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
def with_hf_config(
self,
hf_config: PretrainedConfig,
architectures: Optional[list[str]] = None,
) -> "VllmConfig":
if architectures is not None:
hf_config = copy.deepcopy(hf_config)
hf_config.architectures = architectures

model_config = copy.deepcopy(self.model_config)
model_config.hf_config = hf_config

Expand Down
4 changes: 1 addition & 3 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,10 @@ def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config,
architectures=architectures)
model_class, _ = get_model_architecture(model_config)

signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
Expand Down
10 changes: 3 additions & 7 deletions vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Optional, Tuple, Type
from typing import Tuple, Type

import torch
from torch import nn
Expand All @@ -20,12 +20,8 @@ def set_default_torch_dtype(dtype: torch.dtype):


def get_model_architecture(
model_config: ModelConfig,
*,
architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
if architectures is None:
architectures = getattr(model_config.hf_config, "architectures", [])
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])

# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
Expand Down
17 changes: 10 additions & 7 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,14 +444,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
self.lm_head = PPMissingLayer()

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
Expand Down
117 changes: 32 additions & 85 deletions vllm/model_executor/models/qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import lru_cache
from functools import cached_property, lru_cache
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)

Expand All @@ -34,28 +34,19 @@
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}


# # === Audio Inputs === #
class Qwen2AudioInputs(TypedDict):
Expand Down Expand Up @@ -281,25 +272,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.quant_config = quant_config

self.language_model = Qwen2Model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix)
self.unpadded_vocab_size = config.text_config.vocab_size
if config.text_config.tie_word_embeddings:
self.lm_head = self.language_model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = get_sampler()
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler

return get_sampler()

def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
Expand Down Expand Up @@ -414,72 +403,30 @@ def forward(
multimodal_embeddings)
input_ids = None

hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.config.text_config.tie_word_embeddings
and "lm_head.weight" in name):
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name or 'audio' in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
Loading

0 comments on commit 0b1345f

Please sign in to comment.