Skip to content

Commit

Permalink
[Model] Rename MiniCPMVQwen2 to MiniCPMV2.6 (#7273)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeejeelee authored Aug 8, 2024
1 parent 6dffa4b commit 757ac70
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ Vision Language Models
-
* - :code:`MiniCPMV`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-

.. note::
Expand Down
51 changes: 35 additions & 16 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ def run_llava(question):
prompt = f"USER: <image>\n{question}\nASSISTANT:"

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question):

prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")

return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question):

prompt = f"{question}\n"
llm = LLM(model="adept/fuyu-8b")

return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# Phi-3-Vision
Expand All @@ -59,7 +59,8 @@ def run_phi3v(question):
trust_remote_code=True,
max_num_seqs=5,
)
return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# PaliGemma
Expand All @@ -68,16 +69,17 @@ def run_paligemma(question):
# PaliGemma has special prompt format for VQA
prompt = "caption en"
llm = LLM(model="google/paligemma-3b-mix-224")

return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question):

prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b")
return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# MiniCPM-V
Expand All @@ -89,13 +91,26 @@ def run_minicpmv(question):
# model_name = "HwwwH/MiniCPM-V-2"

# 2.5
model_name = "openbmb/MiniCPM-Llama3-V-2_5"
# model_name = "openbmb/MiniCPM-Llama3-V-2_5"

#2.6
model_name = "openbmb/MiniCPM-V-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
llm = LLM(
model=model_name,
trust_remote_code=True,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
# stop_token_ids = [tokenizer.eos_id]

# 2.5
# stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

# 2.6
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

messages = [{
'role': 'user',
Expand All @@ -104,7 +119,7 @@ def run_minicpmv(question):
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return llm, prompt
return llm, prompt, stop_token_ids


# InternVL
Expand All @@ -118,7 +133,8 @@ def run_internvl(question):
trust_remote_code=True,
max_num_seqs=5,
)
return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


# BLIP-2
Expand All @@ -128,7 +144,8 @@ def run_blip2(question):
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b")
return llm, prompt
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
Expand All @@ -149,11 +166,13 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")

llm, prompt = model_example_map[model](question)
llm, prompt, stop_token_ids = model_example_map[model](question)

# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=stop_token_ids)

assert args.num_prompts > 0
if args.num_prompts == 1:
Expand Down
29 changes: 15 additions & 14 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def __init__(

self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)

if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
Expand All @@ -225,7 +224,6 @@ def __init__(
nn.Identity()(*args, **kwargs),
None,
)

self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
Expand Down Expand Up @@ -261,7 +259,6 @@ def __init__(
norm_layer)

self.adaptive = adaptive

pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
Expand Down Expand Up @@ -717,7 +714,7 @@ def is_default_weight_loading(self, name: str) -> bool:
raise NotImplementedError


class MiniCPMV2(MiniCPMVBaseModel):
class MiniCPMV2_0(MiniCPMVBaseModel):

def __init__(
self,
Expand Down Expand Up @@ -890,10 +887,7 @@ def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name


# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as it's name. The name may need
# to be modified in the future.
class MiniCPMVQwen2(MiniCPMVBaseModel):
class MiniCPMV2_6(MiniCPMVBaseModel):

def __init__(
self,
Expand All @@ -903,6 +897,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 6)

def init_llm(
self,
Expand Down Expand Up @@ -930,6 +925,7 @@ def init_vision_module(self) -> nn.Module:

def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16):
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
Expand Down Expand Up @@ -989,6 +985,13 @@ def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name


_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,
(2, 5): MiniCPMV2_5,
(2, 6): MiniCPMV2_6
}


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
Expand Down Expand Up @@ -1016,11 +1019,9 @@ def __new__(
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
if version == (2, 0):
instance_class = MiniCPMV2
elif version == (2, 5):
instance_class = MiniCPMV2_5
else:
instance_class = MiniCPMVQwen2
instance_class = _SUPPORT_VERSION.get(version, None)
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(config, multimodal_config, cache_config,
quant_config)

0 comments on commit 757ac70

Please sign in to comment.