diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index c6786c363ab4a..94a4ab882c1a9 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1,6 +1,5 @@
 from dataclasses import dataclass, fields
 from functools import cached_property
-from itertools import tee
 from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 import numpy
@@ -359,38 +358,33 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")
 
-        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
-            return is_vision_encoder_weights(
-                weight) or is_vision_lang_adapter_weights(weight)
-
-        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
-            weights, 3)
-
-        # llm
-        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
-        self.language_model.load_weights(llm_weights)
-
-        # vision encoder
-        vision_encoder_weights = filter(is_vision_encoder_weights,
-                                        vision_encoder_weights)
+        # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
-        for name, loaded_weight in vision_encoder_weights:
-            # cut 'vision_encoder.'
-            name = '.'.join(name.split(".")[1:])
-            param = vision_encoder_dict[name]
-
-            default_weight_loader(param, loaded_weight)
-
-        # adapter
-        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
-                                             vision_lang_adapter_weights)
-        vision_lang_adpter_dict = dict(
+        vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
-        for name, loaded_weight in vision_lang_adapter_weights:
-            # cut 'vision_language_adapter.'
-            name = '.'.join(name.split(".")[1:])
-            param = vision_lang_adpter_dict[name]
-            default_weight_loader(param, loaded_weight)
+
+        def llm_weights_generator():
+            # Single pass over weights
+            for name, w in weights:
+                if is_vision_encoder_weights((name, w)):
+                    # Load vision encoder weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = vision_encoder_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_vision_lang_adapter_weights((name, w)):
+                    # Load vision-language adapter weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = vision_lang_adapter_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                else:
+                    # LLM weights: yield them to be loaded
+                    # by language_model.load_weights
+                    yield (name, w)
+
+        # Now we call the language model load with the generator
+        self.language_model.load_weights(llm_weights_generator())
 
 
 # Vision encoder
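For context, the old path split the checkpoint stream three ways with `itertools.tee(weights, 3)`; since the LLM iterator is drained first, `tee` has to buffer essentially every tensor until the vision iterators catch up. The new path makes a single pass: vision weights are loaded as a side effect of iteration and only LLM weights are yielded onward. Below is a minimal standalone sketch of that pattern; `fake_weight_stream`, `load_param`, and `single_pass_load` are illustrative stand-ins (not vLLM APIs) for the checkpoint iterator, `default_weight_loader`, and `llm_weights_generator`.

```python
from typing import Iterable, Iterator, Tuple

# Stand-in for a checkpoint iterator; in vLLM this streams real tensors.
def fake_weight_stream() -> Iterator[Tuple[str, list]]:
    yield ("vision_encoder.layers.0.weight", [1.0])
    yield ("vision_language_adapter.w_in.weight", [2.0])
    yield ("language_model.layers.0.weight", [3.0])

# Illustrative loader; the real code calls default_weight_loader on a Parameter.
def load_param(store: dict, name: str, tensor) -> None:
    store[name] = tensor

def single_pass_load(weights: Iterable[Tuple[str, list]],
                     vision_store: dict) -> Iterator[Tuple[str, list]]:
    """Consume the stream once: load vision weights eagerly, yield the rest."""
    for name, w in weights:
        if name.startswith(("vision_encoder", "vision_language_adapter")):
            load_param(vision_store, name, w)
        else:
            # Deferred: the language model's own load_weights pulls these.
            yield (name, w)

vision_store: dict = {}
llm_weights = list(single_pass_load(fake_weight_stream(), vision_store))
print(vision_store)  # vision weights loaded as a side effect of iteration
print(llm_weights)   # [('language_model.layers.0.weight', [3.0])]
```

The key property is that nothing is retained between iterations, so peak extra memory is bounded by a single in-flight tensor rather than by the stream that `tee` had to buffer.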