Add Qwen2 support for pretraining and finetuning #1573

Open
wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion llava/__init__.py
@@ -1 +1 @@
from .model import LlavaLlamaForCausalLM
from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM
63 changes: 56 additions & 7 deletions llava/conversation.py
@@ -13,6 +13,8 @@ class SeparatorStyle(Enum):
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
QWEN_2 = auto() # fix: add qwen2
CHATML = auto()


@dataclasses.dataclass
@@ -51,6 +53,27 @@ def get_prompt(self):
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.QWEN_2: # fix: add qwen2
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.CHATML:
ret = "" if self.system == "" else self.system + self.sep + "\n"
for role, message in messages:
if message:
if type(message) is tuple:
message, images = message
message = "<image>" * len(images) + message
ret += role + "\n" + message + self.sep + "\n"
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
@@ -71,8 +94,8 @@ def get_prompt(self):
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
def wrap_sys(msg): return f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
def wrap_inst(msg): return f"[INST] {msg} [/INST]"
ret = ""

for i, (role, message) in enumerate(messages):
@@ -82,7 +105,8 @@ def get_prompt(self):
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0: message = wrap_sys(self.system) + message
if i == 0:
message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
@@ -369,12 +393,38 @@ def dict(self):
sep="<|im_end|>",
)

default_conversation = conv_vicuna_v1

# fix: add qwen2
conv_qwen_2 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="qwen_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.QWEN_2,
sep=" ",
sep2="<|endoftext|>",
)

# conv_qwen_2 = Conversation(
# system="""<|im_start|>system
# You are a helpful assistant.""",
# roles=("<|im_start|>user", "<|im_start|>assistant"),
# version="qwen_v2",
# messages=[],
# offset=0,
# sep_style=SeparatorStyle.CHATML,
# sep="<|im_end|>",
# )

default_conversation = conv_qwen_2
conv_templates = {
"default": conv_vicuna_v0,
"default": conv_qwen_2,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"qwen_2": conv_qwen_2,
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"chatml_direct": conv_chatml_direct,
@@ -391,6 +441,5 @@ def dict(self):
"mpt": conv_mpt,
}


if __name__ == "__main__":
print(default_conversation.get_prompt())
print("conversation:", default_conversation.get_prompt())
4 changes: 2 additions & 2 deletions llava/eval/model_qa.py
@@ -17,8 +17,7 @@ def eval_model(model_name, questions_file, answers_file):
model_name = os.path.expanduser(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name,
torch_dtype=torch.float16).cuda()

torch_dtype=torch.float16).cuda()

ques_file = open(os.path.expanduser(questions_file), "r")
ans_file = open(os.path.expanduser(answers_file), "w")
@@ -54,6 +53,7 @@ def eval_model(model_name, questions_file, answers_file):
ans_file.flush()
ans_file.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
1 change: 1 addition & 0 deletions llava/eval/model_vqa.py
@@ -83,6 +83,7 @@ def eval_model(args):
ans_file.flush()
ans_file.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
1 change: 1 addition & 0 deletions llava/eval/model_vqa_loader.py
@@ -125,6 +125,7 @@ def eval_model(args):
# ans_file.flush()
ans_file.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
18 changes: 10 additions & 8 deletions llava/eval/model_vqa_mmbench.py
@@ -41,6 +41,7 @@ def is_none(value):
return True
return False


def get_options(row, options):
parsed_options = []
for option in options:
@@ -124,21 +125,22 @@ def eval_model(args):

ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
"round_id": round_idx,
"prompt": cur_prompt,
"text": outputs,
"options": options,
"option_char": cur_option_char,
"answer_id": ans_id,
"model_id": model_name,
"metadata": {}}) + "\n")
"round_id": round_idx,
"prompt": cur_prompt,
"text": outputs,
"options": options,
"option_char": cur_option_char,
"answer_id": ans_id,
"model_id": model_name,
"metadata": {}}) + "\n")
ans_file.flush()

# rotate options
options = options[1:] + options[:1]
cur_option_char = cur_option_char[1:] + cur_option_char[:1]
ans_file.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
1 change: 1 addition & 0 deletions llava/eval/model_vqa_science.py
@@ -93,6 +93,7 @@ def eval_model(args):
ans_file.flush()
ans_file.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
5 changes: 3 additions & 2 deletions llava/mm_utils.py
@@ -212,6 +212,7 @@ def get_model_name_from_path(model_path):
else:
return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
@@ -226,7 +227,7 @@ def __init__(self, keywords, tokenizer, input_ids):
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]

def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
@@ -239,7 +240,7 @@ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor
if keyword in outputs:
return True
return False

def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
outputs = []
for i in range(output_ids.shape[0]):
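
The KeywordsStoppingCriteria changes above are whitespace-only, but this class is what a Qwen2 generation loop would use to stop at the new template's terminator. A self-contained sketch, assuming a Qwen2 tokenizer from the Hub (the repo id below is only an example) and that "<|endoftext|>" (conv_qwen_2.sep2 in this PR) is the stop keyword:

import torch
from transformers import AutoTokenizer
from llava.mm_utils import KeywordsStoppingCriteria

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")  # example checkpoint
stop_str = "<|endoftext|>"  # conv_qwen_2.sep2 in this PR

prompt_ids = tokenizer("USER: hi ASSISTANT:", return_tensors="pt").input_ids
criteria = KeywordsStoppingCriteria([stop_str], tokenizer, prompt_ids)

# Simulate generation emitting the stop token right after the prompt; the criteria should fire.
stop_ids = tokenizer(stop_str, return_tensors="pt", add_special_tokens=False).input_ids
output_ids = torch.cat([prompt_ids, stop_ids], dim=1)
print(criteria(output_ids, scores=None))  # True once the keyword appears past the prompt length

In real use the same object would be passed to model.generate(..., stopping_criteria=[criteria]).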
1 change: 1 addition & 0 deletions llava/model/__init__.py
@@ -2,5 +2,6 @@
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig
except:
pass
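
With the registry import above in place, the new wrapper can be constructed like any other Hugging Face causal-LM class. A hedged sketch; the checkpoint path is hypothetical and assumes a model saved with this PR's LlavaQwen2ForCausalLM:

import torch
from llava.model import LlavaQwen2ForCausalLM

# Load a hypothetical checkpoint trained with this PR; half precision on GPU, as in the eval scripts.
model = LlavaQwen2ForCausalLM.from_pretrained(
    "path/to/llava-qwen2-checkpoint",  # hypothetical path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).cuda()
model.eval()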
1 change: 1 addition & 0 deletions llava/model/builder.py
@@ -66,6 +66,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
else:
# this is probably from HF Hub
from huggingface_hub import hf_hub_download

def load_from_hf(repo_id, filename, subfolder=None):
cache_file = hf_hub_download(
repo_id=repo_id,
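
The builder change in this hunk is cosmetic, but for completeness, loading a Qwen2-based LLaVA checkpoint through the builder would follow the same call pattern as the existing models. Whether load_pretrained_model actually dispatches to LlavaQwen2ForCausalLM depends on builder logic outside the lines shown here; the return values below follow upstream LLaVA and the path is hypothetical:

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "path/to/llava-qwen2-checkpoint"  # hypothetical
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=model_name
)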
28 changes: 14 additions & 14 deletions llava/model/language_model/llava_mpt.py
@@ -18,7 +18,7 @@
import torch

from transformers import AutoConfig, AutoModelForCausalLM, \
MptConfig, MptForCausalLM, MptModel
MptConfig, MptForCausalLM, MptModel
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


@@ -32,7 +32,7 @@ class LlavaMptModel(LlavaMetaModel, MptModel):
def __init__(self, config: MptConfig):
config.hidden_size = config.d_model
super(LlavaMptModel, self).__init__(config)

def embed_tokens(self, x):
return self.wte(x)

@@ -58,20 +58,20 @@ def _set_gradient_checkpointing(self, module, value=False):
module.gradient_checkpointing = value

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
images=None):
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
images=None):

input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

return super().forward(
input_ids,
past_key_values=past_key_values,