Skip to content

Commit

Permalink
[Pixtral] Improve loading (#11040)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored Dec 10, 2024
1 parent 980ad39 commit bc192a2
Showing 1 changed file with 25 additions and 31 deletions.
56 changes: 25 additions & 31 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass, fields
from functools import cached_property
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union

import numpy
Expand Down Expand Up @@ -359,38 +358,33 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")

def is_vision_weights(weight: Tuple[str, torch.Tensor]):
return is_vision_encoder_weights(
weight) or is_vision_lang_adapter_weights(weight)

llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
weights, 3)

# llm
llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
self.language_model.load_weights(llm_weights)

# vision encoder
vision_encoder_weights = filter(is_vision_encoder_weights,
vision_encoder_weights)
# Get references to parameters for direct loading
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
for name, loaded_weight in vision_encoder_weights:
# cut 'vision_encoder.'
name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[name]

default_weight_loader(param, loaded_weight)

# adapter
vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
vision_lang_adapter_weights)
vision_lang_adpter_dict = dict(
vision_lang_adapter_dict = dict(
self.vision_language_adapter.named_parameters())
for name, loaded_weight in vision_lang_adapter_weights:
# cut 'vision_language_adapter.'
name = '.'.join(name.split(".")[1:])
param = vision_lang_adpter_dict[name]
default_weight_loader(param, loaded_weight)

def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_encoder_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_lang_adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
else:
# LLM weights: yield them to be loaded
# by language_model.load_weights
yield (name, w)

# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())


# Vision encoder
Expand Down

0 comments on commit bc192a2

Please sign in to comment.