[Model] Composite weight loading for multimodal Qwen2 #10944

Merged · 5 commits · Dec 7, 2024
10 changes: 9 additions & 1 deletion vllm/config.py
@@ -2470,7 +2470,15 @@ def _get_quantization_config(
             return quant_config
         return None

-    def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
+    def with_hf_config(
+        self,
+        hf_config: PretrainedConfig,
+        architectures: Optional[list[str]] = None,
+    ) -> "VllmConfig":
+        if architectures is not None:
+            hf_config = copy.deepcopy(hf_config)
+            hf_config.architectures = architectures
+
         model_config = copy.deepcopy(self.model_config)
         model_config.hf_config = hf_config
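The new architectures parameter lets a composite model derive a sub-config whose registry lookup resolves to a different architecture. A minimal usage sketch, with illustrative variable names standing in for what is available inside a model's __init__ (this is the path the Qwen2-Audio change below exercises via init_vllm_registered_model):

# Illustrative sketch: derive a VllmConfig for a text backbone while
# overriding the architecture used for model-registry lookup.
# `vllm_config` and `config` stand in for the objects available inside
# a model's __init__; they are not defined in this PR's diff.
text_vllm_config = vllm_config.with_hf_config(
    config.text_config,
    architectures=["Qwen2ForCausalLM"],
)
# `config.text_config` itself is not mutated: with_hf_config deep-copies
# the HF config before assigning the new architectures list.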
4 changes: 1 addition & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -101,12 +101,10 @@ def _initialize_model(
     vllm_config: VllmConfig,
     *,
     prefix: str = "",
-    architectures: Optional[list[str]] = None,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_config = vllm_config.model_config
-    model_class, _ = get_model_architecture(model_config,
-                                            architectures=architectures)
+    model_class, _ = get_model_architecture(model_config)

     signatures = inspect.signature(model_class.__init__)
     all_params = [param.name for param in signatures.parameters.values()]
10 changes: 3 additions & 7 deletions vllm/model_executor/model_loader/utils.py
@@ -1,6 +1,6 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Tuple, Type
+from typing import Tuple, Type

 import torch
 from torch import nn
@@ -20,12 +20,8 @@ def set_default_torch_dtype(dtype: torch.dtype):


 def get_model_architecture(
-    model_config: ModelConfig,
-    *,
-    architectures: Optional[list[str]] = None,
-) -> Tuple[Type[nn.Module], str]:
-    if architectures is None:
-        architectures = getattr(model_config.hf_config, "architectures", [])
+        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
+    architectures = getattr(model_config.hf_config, "architectures", [])

     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
17 changes: 10 additions & 7 deletions vllm/model_executor/models/qwen2.py
@@ -444,14 +444,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))

-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
Comment on lines +454 to +455

Member: I worry about this prefix being correct now, since in the model checkpoint on HF the weights are just at lm_head, and we do the same when specifying the ignored module in compressed-tensors: https://huggingface.co/nm-testing/Qwen2-VL-2B-Instruct-FP8-dynamic/blob/8a9ad03741a56273d91cf71afbe9b5baa9509e17/config.json#L186

We could add this model to vllm/tests/models/decoder_only/vision_language/test_models.py to verify.

Member Author: It should be handled by the weight mapper inside the Qwen2-VL weight loading logic.

Member Author: Qwen2 (language-only) is already being tested in the language models tests.

         else:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config,
-                                          prefix=maybe_prefix(
-                                              prefix, "lm_head"))
+            self.lm_head = PPMissingLayer()

         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
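On the review thread above: a minimal sketch of the prefix remapping the author attributes to the Qwen2-VL weight loading logic, assuming vLLM's WeightsMapper helper; the exact prefixes in the real Qwen2-VL code may differ.

from vllm.model_executor.models.utils import WeightsMapper

# Assumed mapping, for illustration: checkpoint weights stored at the
# top-level "lm_head." and "model." prefixes are redirected into the
# wrapped language model before parameter names are resolved.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
    "lm_head.": "language_model.lm_head.",
    "model.": "language_model.model.",
})

# AutoWeightsLoader can then apply the mapper while loading:
#     loader = AutoWeightsLoader(self)
#     loader.load_weights(weights, mapper=hf_to_vllm_mapper)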
117 changes: 32 additions & 85 deletions vllm/model_executor/models/qwen2_audio.py
@@ -19,7 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
-from functools import lru_cache
+from functools import cached_property, lru_cache
 from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                     Union)

@@ -34,28 +34,19 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.utils import consecutive_placeholder_ranges
 from vllm.sequence import IntermediateTensors, SequenceData

 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import merge_multimodal_embeddings
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)

 logger = init_logger(__name__)

-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
-

 # # === Audio Inputs === #
 class Qwen2AudioInputs(TypedDict):
@@ -281,25 +272,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.quant_config = quant_config

-        self.language_model = Qwen2Model(
-            vllm_config=vllm_config.with_hf_config(config.text_config),
-            prefix=prefix)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        if config.text_config.tie_word_embeddings:
-            self.lm_head = self.language_model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
-                                          config.text_config.hidden_size,
-                                          quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = get_sampler()
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+            architectures=["Qwen2ForCausalLM"],
+        )

         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

+    @cached_property
+    def sampler(self):
+        if hasattr(self.language_model, "sampler"):
+            return self.language_model.sampler
+
+        return get_sampler()
+
     def _validate_and_reshape_mm_tensor(self,
                                         mm_input: Union[torch.Tensor,
                                                         List[torch.Tensor]],
@@ -414,72 +403,30 @@ def forward(
                 multimodal_embeddings)
             input_ids = None

-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            intermediate_tensors,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  intermediate_tensors,
+                                                  inputs_embeds=inputs_embeds)
         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)

     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if (self.config.text_config.tie_word_embeddings
-                    and "lm_head.weight" in name):
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name or 'audio' in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
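Taken together, the pattern this PR establishes for composite multimodal models: build the text backbone through the model registry so it arrives as a complete *ForCausalLM (with its own lm_head, logits processor, sampler, and load_weights), then delegate logits, sampling, and weight loading to it. A condensed sketch under stated assumptions; MyMultimodalModel is hypothetical and omits the multimodal plumbing:

from functools import cached_property
from typing import Iterable, Set, Tuple

import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              init_vllm_registered_model,
                                              maybe_prefix)


class MyMultimodalModel(nn.Module):
    """Hypothetical composite model following the pattern in this PR."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        # The text backbone resolves through the registry to a full
        # Qwen2ForCausalLM, including lm_head and logits processor.
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

    @cached_property
    def sampler(self):
        # Prefer the backbone's sampler when it exposes one, as the
        # Qwen2-Audio change above does.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return get_sampler()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]
                     ) -> Set[str]:
        # AutoWeightsLoader walks submodules and hands each weight to the
        # owning submodule's load_weights, so per-model mapping code
        # (stacked params, GPTQ bias skips, kv-scale remaps) lives once,
        # in the backbone.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)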