[Model] Clean up MiniCPMV (#10751)
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Nov 29, 2024
1 parent c83919c commit fa6ecb9
Showing 7 changed files with 149 additions and 215 deletions.
19 changes: 16 additions & 3 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -295,16 +295,29 @@
)
],
),
"minicpmv": VLMTestInfo(
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id],
postprocess_inputs=model_utils.wrap_inputs_post_processor,
hf_output_post_proc=model_utils.minicmpv_trunc_hf_output,
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
"minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
postprocess_inputs=model_utils.ignore_inputs_post_processor(
"image_sizes"
),
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.
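For reference, a hedged sketch of what the `get_stop_token_ids` hook in the new `minicpmv_26` test entry resolves to. Only the model name and the two token strings come from the diff above; the use of `AutoTokenizer` and `trust_remote_code` here is an assumption for illustration, since the test framework supplies the tokenizer itself.

```python
# Illustration only: resolve the chat terminator tokens used as stop ids
# for the "minicpmv_26" entry. AutoTokenizer/trust_remote_code are assumed.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6",
                                    trust_remote_code=True)
stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
print(stop_token_ids)  # two integer ids; generation stops on either one
```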
@@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,


####### Post-processors for HF outputs
def minicmpv_trunc_hf_output(hf_output: RunnerOutput,
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|eot_id|>"):
@@ -197,6 +197,17 @@ def process(hf_inputs: BatchEncoding, dtype: str):
return process


def ignore_inputs_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which ignores a given key."""

def process(hf_inputs: BatchEncoding, dtype: str):
del hf_inputs[hf_inp_key]
return hf_inputs

return process


def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
return {"model_inputs": hf_inputs}

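A minimal usage sketch of the two input post-processors defined above, assuming `ignore_inputs_post_processor` and `wrap_inputs_post_processor` from that hunk are in scope; the `BatchEncoding` contents are made up purely for illustration.

```python
# Illustration only: how the helpers above transform HF inputs.
from transformers import BatchEncoding

hf_inputs = BatchEncoding({"input_ids": [[1, 2, 3]],
                           "image_sizes": [[448, 448]]})

# ignore_inputs_post_processor drops a key before the inputs reach the HF
# model (used by the "minicpmv_26" test entry for "image_sizes").
drop_image_sizes = ignore_inputs_post_processor("image_sizes")
cleaned = drop_image_sizes(hf_inputs, dtype="half")
assert "image_sizes" not in cleaned

# wrap_inputs_post_processor nests everything under "model_inputs"
# (used by the "minicpmv_25" test entry).
wrapped = wrap_inputs_post_processor(cleaned, dtype="half")
assert set(wrapped) == {"model_inputs"}
```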
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -242,7 +242,7 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
@@ -261,7 +261,7 @@ def _load_model_weight_or_group_weight_scale(self, shard_dim: int,

def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
@@ -274,7 +274,7 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
tp_rank=tp_rank)

def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):

# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
@@ -292,7 +292,7 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
expert_data.copy_(loaded_weight)

def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
@@ -311,7 +311,7 @@ def _load_single_value(self, param: torch.nn.Parameter,
param_data[expert_id] = loaded_weight

def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

if shard_id == "w2":
self._load_w2(shard_id=shard_id,
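The `fused_moe/layer.py` hunks above only correct type annotations: `torch.tensor` (lowercase) is the tensor-constructing function, while `torch.Tensor` (capital) is the class and therefore the proper type hint. A minimal illustration:

```python
import torch

def halve(loaded_weight: torch.Tensor) -> torch.Tensor:
    # torch.Tensor is the class, so it is valid as a type annotation.
    return loaded_weight * 0.5

w = torch.tensor([1.0, 2.0])        # torch.tensor builds a tensor at runtime
print(isinstance(w, torch.Tensor))  # True
print(halve(w))                     # tensor([0.5000, 1.0000])
```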
153 changes: 79 additions & 74 deletions vllm/model_executor/models/minicpm.py
@@ -52,7 +52,7 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

@@ -378,6 +378,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.num_experts = getattr(self.config, "num_experts", 0)
self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
@@ -437,6 +438,73 @@ def forward(
hidden_states = self.norm(hidden_states)
return hidden_states

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.num_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
@@ -480,8 +548,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.cache_config = cache_config
self.quant_config = quant_config

self.num_experts = getattr(self.config, "num_experts", 0)
self._init_model(vllm_config=vllm_config, prefix=prefix)
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -506,8 +575,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model.make_empty_intermediate_tensors)

def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPMModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -546,72 +614,9 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.num_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
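A hedged sketch (not the actual vLLM `AutoWeightsLoader` implementation) of what the simplified `MiniCPMForCausalLM.load_weights` above amounts to: weights under a skipped prefix are dropped, and the rest are delegated to the inner model's loader, here the `MiniCPMModel.load_weights` added earlier in this commit. Per-child prefix handling is glossed over; the helper name below is hypothetical.

```python
# Conceptual sketch only. The real AutoWeightsLoader also walks child
# modules and strips their name prefixes before delegating.
from typing import Iterable, Set, Tuple
import torch


def load_weights_sketch(model, weights: Iterable[Tuple[str, torch.Tensor]],
                        tie_word_embeddings: bool) -> Set[str]:
    skip_prefixes = ["lm_head."] if tie_word_embeddings else []
    kept = ((name.removeprefix("model."), w) for name, w in weights
            if not any(name.startswith(p) for p in skip_prefixes))
    # MiniCPMModel.load_weights (added in this commit) still performs the
    # stacked-parameter and expert-parameter remapping itself.
    return model.load_weights(kept)
```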
5 changes: 2 additions & 3 deletions vllm/model_executor/models/minicpm3.py
@@ -40,7 +40,7 @@
MiniCPMForCausalLM,
MiniCPMModel)

from .utils import make_layers, maybe_prefix
from .utils import make_layers


class MiniCPM3Attention(nn.Module):
@@ -248,5 +248,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
}

def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPM3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
