Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/support qwenvl glm4-v phi3-v (conflict resolving) #4377

Closed
wants to merge 41 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
fbf19f8
Basic support for webui.
marko1616 Jun 19, 2024
95b8a1d
Basic support for GLM4V
marko1616 Jun 19, 2024
61a0880
Merge branch 'hiyouga:main' into feature/Support-Qwenvl
marko1616 Jun 19, 2024
8044804
Pass ruff check.
marko1616 Jun 19, 2024
c58be83
Half of sft support and bug fix.
marko1616 Jun 20, 2024
4b01584
GLM4v lora sft support
marko1616 Jun 21, 2024
c233520
Little fix
marko1616 Jun 22, 2024
078c85d
Merge branch 'main' into feature/Support-Qwenvl
hiyouga Jun 24, 2024
67542a0
Fix requirements.txt
marko1616 Jun 25, 2024
e6aa967
fix conflict
BUAADreamer Jun 28, 2024
f698b43
QwenVL sft & webui train bugfix.
marko1616 Jun 29, 2024
3fa3a0b
phi3v infer support & rename.
marko1616 Jun 30, 2024
06823f4
Add rm,pt,ppo,kto,dpo support for glm4v(Not tested).
marko1616 Jun 30, 2024
40e817c
Merge branch 'hiyouga:main' into feature/Support-Qwenvl
marko1616 Jun 30, 2024
4e4f959
little fix
marko1616 Jun 30, 2024
4f564a1
Pass ruff
marko1616 Jun 30, 2024
5065e87
Merge branch 'main' into feature/Support-Qwenvl
marko1616 Jun 30, 2024
c37465e
Style check & fix requirements.txt
marko1616 Jul 1, 2024
9e7bb3f
Bugfix
marko1616 Jul 2, 2024
17e5d7d
Merge branch 'main' into feature/Support-Qwenvl
marko1616 Jul 2, 2024
5fe2862
Change implementation.
marko1616 Jul 2, 2024
e871b03
Merge remote
marko1616 Jul 2, 2024
b8cf95a
Update README, fix template constant, and add download source for phi3v.
Jul 2, 2024
c4ac67a
Merge pull request #1 from Radeon-grapchis/feature/Support-Qwenvl
marko1616 Jul 2, 2024
e6099f5
Name style fix.
marko1616 Jul 2, 2024
eb38fe2
modify glm_4v 9B desc
BUAADreamer Jul 3, 2024
51931b9
add torchvision to pass test
BUAADreamer Jul 3, 2024
a0ad0b5
modify dict in common
BUAADreamer Jul 3, 2024
3acefbc
Support latest glm4v.
marko1616 Jul 3, 2024
4146242
Phi3v lora sft fix.
marko1616 Jul 3, 2024
70ac8ea
fix get_template.
marko1616 Jul 3, 2024
ea60231
Update for unsupervised dataset.
marko1616 Jul 4, 2024
b932bc0
Phi3v dataset processor fix.
marko1616 Jul 6, 2024
36932dd
Merge branch 'main' into feature/Support-Qwenvl
marko1616 Jul 18, 2024
3c2ecba
Conflict fix
marko1616 Jul 18, 2024
3f9ccb3
RLHF support.
marko1616 Jul 19, 2024
9c6587e
glm4v pairwise dataset support
marko1616 Jul 19, 2024
cfe0652
Merge branch 'main' into feature/Support-Qwenvl
marko1616 Jul 31, 2024
19a4cf7
Merge branch 'main' into feature/Support-Qwenvl
marko1616 Aug 20, 2024
e9d902b
Name fix.
marko1616 Aug 22, 2024
65b64be
ruff pass.
marko1616 Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions src/llamafactory/chat/hf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
import asyncio
import concurrent.futures
import os
import pathlib
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torchvision
from PIL import Image
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from ..webui.common import DEFAULT_CACHE_DIR
from .base_engine import BaseEngine, Response


Expand Down Expand Up @@ -58,6 +62,7 @@ def __init__(
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.model_args = model_args
self.generating_args = generating_args.to_dict()
try:
asyncio.get_event_loop()
Expand All @@ -75,6 +80,7 @@ def _process_args(
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
model_args: "ModelArguments",
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
Expand All @@ -86,16 +92,26 @@ def _process_args(
and image is not None
and not hasattr(processor, "image_seq_length")
and template.image_token not in messages[0]["content"]
): # llava-like models
and model_args.visual_inputs_type == "vision_tower"
):
# llava-like models
messages[0]["content"] = template.image_token + messages[0]["content"]
elif image is not None and model_args.visual_inputs_type == "vision_token":
# Add image pathlike token as vision input
image_path = pathlib.Path(DEFAULT_CACHE_DIR) / "temp.png"
marko1616 marked this conversation as resolved.
Show resolved Hide resolved
Image.fromarray(image).convert("RGB").save(image_path)
messages[-1]["content"] = template.format_image.apply(content=os.fspath(image_path))[0] + messages[-1]["content"]
elif image is not None and model_args.visual_inputs_type == "vision_message_embed":
messages[-1]["content"] = template.format_image.apply()[0] + messages[-1]["content"]
marko1616 marked this conversation as resolved.
Show resolved Hide resolved

paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
if processor is not None and image is not None: # add image features
# add image features for vision tower
if processor is not None and image is not None and template.format_image is None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
Expand Down Expand Up @@ -163,6 +179,17 @@ def _process_args(
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
if image is not None and model_args.visual_inputs_type == "vision_message_embed":
transform = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(
(1120, 1120), interpolation=torchvision.transforms.InterpolationMode.BICUBIC
),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
gen_kwargs["images"] = transform(Image.fromarray(image)).unsqueeze(0).to(model.device).to(model_args.compute_dtype)

if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
Expand All @@ -177,14 +204,15 @@ def _chat(
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
model_args: "ModelArguments",
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
model, tokenizer, processor, template, generating_args, model_args, messages, system, tools, image, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
Expand Down Expand Up @@ -212,14 +240,15 @@ def _stream_chat(
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
model_args: "ModelArguments",
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
model, tokenizer, processor, template, generating_args, model_args, messages, system, tools, image, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
Expand Down Expand Up @@ -285,6 +314,7 @@ async def chat(
self.processor,
self.template,
self.generating_args,
self.model_args,
messages,
system,
tools,
Expand Down Expand Up @@ -313,6 +343,7 @@ async def stream_chat(
self.processor,
self.template,
self.generating_args,
self.model_args,
messages,
system,
tools,
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/chat/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
"max_lora_rank": model_args.vllm_max_lora_rank,
}

if model_args.visual_inputs:
if model_args.visual_inputs and model_args.visual_inputs_type == "vision_tower":
marko1616 marked this conversation as resolved.
Show resolved Hide resolved
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
Expand Down
6 changes: 5 additions & 1 deletion src/llamafactory/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_dataset(

with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
data_args, training_args, stage, template, tokenizer, processor
data_args, training_args, model_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
Expand All @@ -190,6 +190,10 @@ def get_dataset(

dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

if model_args.visual_inputs_type == "vision_message_embed":
dataset = dataset.rename_column("image_inputs","images")
print(dataset["images"])

marko1616 marked this conversation as resolved.
Show resolved Hide resolved
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path)
Expand Down
4 changes: 3 additions & 1 deletion src/llamafactory/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

from ..hparams import DataArguments
from ..hparams import DataArguments, ModelArguments
from .template import Template


def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
model_args: "ModelArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
Expand Down Expand Up @@ -63,6 +64,7 @@ def get_preprocess_and_print_func(
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
model_args=model_args,
)

print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
Expand Down
21 changes: 17 additions & 4 deletions src/llamafactory/data/processors/processor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import bisect
from typing import TYPE_CHECKING, List, Sequence
from torchvision import transforms

from ...extras.packages import is_pillow_available

Expand Down Expand Up @@ -61,13 +62,25 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks


def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
def get_pixel_values(
    images: Sequence["ImageObject"], processor: "ProcessorMixin", vision_type: str = "vision_tower"
) -> "NDArray":
    r"""
    Processes visual inputs. (currently only supports a single image)

    Args:
        images: sequence of PIL images; only the first image is used. An empty
            sequence yields a blank white placeholder image.
        processor: the HF processor providing an ``image_processor``; only
            consulted when ``vision_type == "vision_tower"`` (may be ``None``
            otherwise).
        vision_type: ``"vision_tower"`` for llava-like models that consume
            processor pixel values, or ``"vision_message_embed"`` for models
            (e.g. GLM-4V) that embed the image tensor directly in the message.

    Returns:
        The image tensor with shape (C, H, W).

    Raises:
        ValueError: if ``vision_type`` is not one of the supported modes
            (previously an unknown mode silently returned ``None``).
    """
    if vision_type == "vision_tower":
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
        return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
    elif vision_type == "vision_message_embed":
        # CLIP-style mean/std normalization at a fixed 1120x1120 input resolution
        # (matches the preprocessing used at inference in hf_engine for this mode).
        transform = transforms.Compose(
            [
                transforms.Resize((1120, 1120), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        image = images[0] if len(images) != 0 else Image.new("RGB", (1120, 1120), (255, 255, 255))
        return transform(image)
    else:
        raise ValueError("Unknown vision_type: {}".format(vision_type))
marko1616 marked this conversation as resolved.
Show resolved Hide resolved


def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
Expand Down
13 changes: 10 additions & 3 deletions src/llamafactory/data/processors/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

from ...hparams import DataArguments
from ...hparams import DataArguments, ModelArguments
from ..template import Template


Expand Down Expand Up @@ -78,19 +78,26 @@ def preprocess_supervised_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
model_args: "ModelArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
if processor is not None and model_args.visual_inputs_type == "vision_tower":
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
elif model_args.visual_inputs_type == "vision_message_embed":
model_inputs["image_inputs"] = []

for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if model_args.visual_inputs_type == "vision_message_embed":
assert len(examples["images"][i]) <= 1,"GLM4v only support 1 image train yet."
model_inputs["image_inputs"].append(get_pixel_values(examples["images"][i], None, "vision_message_embed"))
examples["prompt"][i][-1]["content"] = template.format_image.apply()[0] + examples["prompt"][i][-1]["content"]

input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i],
Expand All @@ -105,7 +112,7 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
if processor is not None and model_args.visual_inputs_type == "vision_tower":
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
Expand Down
31 changes: 31 additions & 0 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Template:
format_tools: "Formatter"
format_separator: "Formatter"
format_prefix: "Formatter"
format_image: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
Expand Down Expand Up @@ -239,6 +240,7 @@ def _register_template(
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
format_image: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
image_token: str = "<image>",
Expand Down Expand Up @@ -290,6 +292,7 @@ def _register_template(
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
format_image=format_image,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
Expand Down Expand Up @@ -686,6 +689,21 @@ def get_template_and_fix_tokenizer(
)


# GLM-4V chat template: reuses the GLM-4 text format and adds an image
# placeholder (format_image) whose special tokens mark where the vision
# embedding is spliced into the message ("vision_message_embed" path).
_register_template(
    name="glm4v",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    # Fixed placeholder: the single <|endoftext|> between the image-boundary
    # tokens stands in for the image embedding; it takes no content argument.
    format_image=EmptyFormatter(slots=["<|begin_of_image|><|endoftext|><|end_of_image|>"]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,  # GLM-4 relies on stop words rather than an EOS token per turn
)


_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
Expand Down Expand Up @@ -815,6 +833,19 @@ def get_template_and_fix_tokenizer(
)


# Qwen-VL chat template: ChatML-style text format plus an image formatter
# that wraps an image file path in <img>...</img> tags, which Qwen-VL's
# tokenizer resolves to vision tokens ("vision_token" path).
_register_template(
    name="qwenvl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    # {{content}} is filled with a path-like string to the image file.
    format_image=StringFormatter(slots=["<img>{{content}}</img>"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,  # use <|im_end|> as the effective EOS token
)


_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
Expand Down
Loading