Skip to content

Commit

Permalink
[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)
Browse files Browse the repository at this point in the history
Signed-off-by: Isotr0py <[email protected]>
  • Loading branch information
Isotr0py authored Nov 30, 2024
1 parent 40bc242 commit 16ee07f
Showing 1 changed file with 111 additions and 102 deletions.
213 changes: 111 additions & 102 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from array import array
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict
from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict

import torch
from einops import rearrange
Expand Down Expand Up @@ -44,7 +44,8 @@
from vllm.transformers_utils.processor import get_processor

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend,
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand Down Expand Up @@ -720,6 +721,42 @@ def forward(
# image_features: (batch_size, num_image, num_patch, d_model)
return image_features

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()

for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


@support_torch_compile
class MolmoModel(nn.Module):
Expand Down Expand Up @@ -804,6 +841,28 @@ def forward(
hidden_states = self.norm(hidden_states)
return hidden_states

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()

for name, loaded_weight in weights:
if "gate_up_proj" in name:
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_proj, up_proj], dim=0)

if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


cached_get_processor = lru_cache(get_processor)

Expand Down Expand Up @@ -1200,103 +1259,53 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

params_mapping = [
("model.transformer.ln_f.weight", "model.norm.weight"),
("attn_out", "self_attn.o_proj"),
("att_proj", "self_attn.qkv_proj"),
("q_norm", "self_attn.q_norm"),
("k_norm", "self_attn.k_norm"),
("attn_norm", "input_layernorm"),
("ff_norm", "post_attention_layernorm"),
]

params_dict = dict(self.named_parameters(remove_duplicate=False))

embedding_weight = dict()
projector_weight = dict()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue

if "wte.embedding" in name:
embedding_weight["embedding"] = loaded_weight
continue

if "wte.new_embedding" in name:
embedding_weight["new_embedding"] = loaded_weight
continue

if "vision_backbone" in name:
if name.startswith("model"):
name = name[len("model."):]
if 'image_projector' in name:
if 'w1' in name:
projector_weight['gate_proj'] = loaded_weight
elif 'w3' in name:
projector_weight['up_proj'] = loaded_weight
elif 'w2' in name:
projector_weight['down_proj'] = loaded_weight
else:
raise ValueError(
f"Unexpected projector weight: {name}")
continue
else:
if "transformer.blocks" in name:
name = name.replace("transformer.blocks", "layers")

if "ff_proj" in name:
name = name.replace("ff_proj", "mlp.gate_up_proj")
assert 'weight' in name
up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_weight, up_weight], dim=0)

elif "ff_out" in name:
if "layers" in name:
name = name.replace("ff_out", "mlp.down_proj")
else:
# lm head
name = name.replace("model.transformer.ff_out",
"lm_head")

else:
for (param_name, weight_name) in params_mapping:
if param_name in name:
name = name.replace(param_name, weight_name)
break

try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
raise ValueError(f"Unexpected weight: {name}") from None

weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

gate_up_proj_weight = torch.cat(
[projector_weight["gate_proj"], projector_weight["up_proj"]],
dim=0)
name = "vision_backbone.image_projector.gate_up_proj.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, gate_up_proj_weight)

down_proj_weight = projector_weight["down_proj"]
name = "vision_backbone.image_projector.down_proj.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, down_proj_weight)

embedding_weight = torch.cat(
[embedding_weight["embedding"], embedding_weight["new_embedding"]],
dim=0)
name = "model.embed_tokens.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, embedding_weight)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
"image_projector.w1.": "image_projector.gate_proj.",
"image_projector.w3.": "image_projector.up_proj.",
"image_projector.w2.": "image_projector.down_proj.",
# language backbone mapping
"att_proj": "self_attn.qkv_proj",
"attn_out": "self_attn.o_proj",
"q_norm": "self_attn.q_norm",
"k_norm": "self_attn.k_norm",
"ff_proj": "mlp.gate_up_proj",
"ff_out": "mlp.down_proj",
"attn_norm": "input_layernorm",
"ff_norm": "post_attention_layernorm",
},
orig_to_new_prefix={
# vision backbone mapping
"model.vision_backbone.": "vision_backbone.",
# language backbone mapping
"model.transformer.blocks.": "model.layers.",
"model.transformer.ln_f.": "model.norm.",
# lm_head is renamed to model.transformer.mlp.down_proj firstly,
# we need to run a second renaming for it
"model.transformer.mlp.down_proj.": "lm_head.",
},
)
loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)


def _get_weights_with_merged_embedding(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
embedding_weights = {}
for name, weight in weights:
if "wte.embedding" in name:
embedding_weights["embedding"] = weight
elif "wte.new_embedding" in name:
embedding_weights["new_embedding"] = weight
else:
yield (name, weight)
# this is compatible with most of quantization,
# because they won't quantize embed_tokens
embedding_weights = torch.cat(
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
dim=0,
)
yield ("model.embed_tokens.weight", embedding_weights)

0 comments on commit 16ee07f

Please sign in to comment.