Merge pull request #6251 from hiyouga/hiyouga/vllm_qwen2vl_infer
[infer] support qwen2vl vllm infer
hiyouga authored Dec 5, 2024
2 parents 967a6c1 + 207f8b0 commit 561a8e5
Showing 4 changed files with 125 additions and 69 deletions.
6 changes: 4 additions & 2 deletions scripts/stat_utils/cal_lr.py
@@ -24,7 +24,7 @@
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
@@ -71,7 +71,9 @@ def calculate_lr(
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
)
else:
raise NotImplementedError(f"Stage is not supported: {stage}.")

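For reference, a hedged sketch of what the rest of cal_lr.py (not shown in this diff) does with the collator chosen above; the `trainset` name and batch size are illustrative, while `tqdm` and `IGNORE_INDEX` come from the imports shown earlier:

```python
# Hedged sketch (assumption: this mirrors the unshown remainder of
# cal_lr.py): iterate collated batches and count supervised tokens.
import torch
from torch.utils.data import DataLoader

dataloader = DataLoader(trainset, batch_size=8, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):  # tqdm is imported at the top of the script
    valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
    total_tokens += torch.numel(batch["labels"])
```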
40 changes: 33 additions & 7 deletions scripts/vllm_infer.py
@@ -16,16 +16,25 @@

import fire
from transformers import Seq2SeqTrainingArguments
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_pillow_available, is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer


if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject


if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

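The vllm and PIL imports are now gated behind availability checks so the script degrades gracefully when optional dependencies are missing. A minimal sketch of what such a check typically reduces to (the importlib-based body is an assumption; the real helpers live in llamafactory.extras.packages):

```python
# Minimal sketch of an availability check (assumed implementation):
from importlib.util import find_spec

def is_vllm_available() -> bool:
    # True only if the "vllm" package can be imported in this environment
    return find_spec("vllm") is not None

if is_vllm_available():
    from vllm import LLM, SamplingParams  # heavy import deferred until needed
```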

def vllm_infer(
model_name_or_path: str,
adapter_name_or_path: str = None,
@@ -64,15 +73,29 @@ def vllm_infer(
)
)

training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)

inputs, prompts, labels = [], [], []
for sample in dataset:
inputs.append({"prompt_token_ids": sample["input_ids"]})
for sample in dataset_module["train_dataset"]:
if sample["images"]:
multi_modal_data = {"image": []}
for image in sample["images"]:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input to be a path or PIL.Image, but got {type(image)}.")

if isinstance(image, str):
image = Image.open(image).convert("RGB")

multi_modal_data["image"].append(image)
else:
multi_modal_data = None

inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
labels.append(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
@@ -100,6 +123,9 @@ def vllm_infer(
"disable_log_stats": True,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)

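For orientation, a hedged sketch of how the engine_args and inputs assembled above are handed to vLLM further down in the script (sampling values are illustrative); limit_mm_per_prompt caps how many images and videos a single prompt may carry:

```python
# Hedged sketch: construct the engine and run batch generation.
llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.95, top_p=0.7, max_tokens=1024)
# Each input dict carries prompt_token_ids plus optional multi_modal_data.
results = llm.generate(inputs, sampling_params)  # pass lora_request=... when an adapter is loaded
preds = [result.outputs[0].text for result in results]
```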
30 changes: 17 additions & 13 deletions src/llamafactory/chat/vllm_engine.py
@@ -19,7 +19,7 @@

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
@@ -67,6 +67,7 @@ def __init__(
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
self.generating_args = generating_args.to_dict()

engine_args = {
@@ -83,6 +84,9 @@
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)

@@ -108,19 +112,21 @@ async def _generate(
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
else:
image_str = self.template.mm_plugin.image_token or ""
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

paired_messages = [
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
for message in messages
] + [{"role": "assistant", "content": ""}]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
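In short, the block above auto-prepends media placeholders when the caller supplies images or videos without any tags in the messages, then lets the template's mm plugin rewrite them. A small self-contained illustration (the "<image>" placeholder value is an assumption):

```python
# Illustration of the auto-prepend step (placeholder value assumed):
IMAGE_PLACEHOLDER = "<image>"
messages = [{"role": "user", "content": "What is in the picture?"}]
images = ["cat.jpg", "dog.jpg"]
if images and not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
    messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
# -> "<image><image>What is in the picture?"
```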
@@ -168,17 +174,15 @@ async def _generate(
)

if images is not None: # add image features
image_data = []
multi_modal_data = {"image": []}
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input to be a path or PIL.Image, but got {type(image)}.")

if isinstance(image, str):
image = Image.open(image).convert("RGB")

image_data.append(image)

multi_modal_data = {"image": image_data}
multi_modal_data["image"].append(image)
else:
multi_modal_data = None

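The same path-or-PIL normalization now appears both here and in scripts/vllm_infer.py; a hedged helper expressing it once (the name to_rgb_image is hypothetical):

```python
# Hypothetical helper equivalent to the inline loop above.
from typing import Union

from PIL import Image
from PIL.Image import Image as ImageObject

def to_rgb_image(image: Union[str, ImageObject]) -> ImageObject:
    """Coerce a file path or PIL image into the form vLLM expects."""
    if isinstance(image, str):
        return Image.open(image).convert("RGB")
    if isinstance(image, ImageObject):
        return image
    raise ValueError(f"Expected image input to be a path or PIL.Image, but got {type(image)}.")
```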
118 changes: 71 additions & 47 deletions src/llamafactory/data/mm_plugin.py
@@ -62,6 +62,7 @@ class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
self.expand_mm_tokens = True

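The new expand_mm_tokens flag is the crux of this PR: during training each placeholder expands into its full token sequence, while the vLLM paths set the flag to False so every image or video stays a single placeholder token that vLLM expands internally. Schematically (placeholder and sequence length illustrative):

```python
# Schematic effect of the flag (values illustrative):
def expand(content: str, image_seqlen: int, expand_mm_tokens: bool) -> str:
    seqlen = image_seqlen if expand_mm_tokens else 1
    return content.replace("<image>", "{{image}}" * seqlen)

expand("<image>Describe.", 576, True)   # 576 tokens, training path
expand("<image>Describe.", 576, False)  # 1 token, vLLM generate path
```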
def _validate_input(
self,
@@ -259,7 +260,7 @@ def process_messages(
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
for message in messages:
content = message["content"]
@@ -310,11 +311,13 @@
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1

num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -359,11 +362,13 @@
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1

num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -376,6 +381,7 @@
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
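A quick worked example of the video token budget computed above, using illustrative values:

```python
# Worked numbers for the formula above (all values illustrative):
height = width = 336
patch_size = 14
num_frames = 8
image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
video_seqlen = image_seqlen // 4 * num_frames  # avg pooling divides by 4 -> 1152
assert (image_seqlen, video_seqlen) == (576, 1152)
```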
@@ -443,7 +449,7 @@ def process_token_ids(
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
@@ -493,14 +499,18 @@ def process_messages(
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")

image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
if self.expand_mm_tokens:
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
else:
replace_str = image_token

content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1

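A small illustration of the grid expansion above for a 2x3-patch image (the token strings are assumed to match Pixtral's special tokens):

```python
# Illustrative expansion: rows end with [IMG_BREAK], the final token
# becomes [IMG_END] (token strings assumed).
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
num_height_tokens, num_width_tokens = 2, 3
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
replace_tokens[-1] = image_end_token
print("".join(replace_tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]
```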
@@ -549,10 +559,27 @@ def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
return image

@override
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
sample_frames = sample_frames // 2 * 2
return sample_frames
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())

if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
frames.append(frames[-1])

frames = self._regularize_images(frames, **kwargs)
results.append(frames)

return results
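A compact check of the sampling-and-padding logic above, with stand-in frames:

```python
# Illustrative frame sampling: 9 indices drawn evenly from a 100-frame
# stream, then padded to an even count for qwen2-vl.
import numpy as np

total_frames, sample_frames = 100, 9
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames = [f"frame_{i}" for i in sample_indices]  # stand-ins for decoded PIL images
if len(frames) % 2 != 0:  # qwen2-vl requires an even number of frames
    frames.append(frames[-1])
assert len(frames) == 10
```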

@override
def process_messages(
Expand All @@ -577,25 +604,19 @@ def process_messages(
if num_image_tokens >= len(image_grid_thw):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
)
num_image_tokens += 1

while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
)
num_video_tokens += 1

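A worked example of the Qwen2-VL token budget above (the 2x2 merge factor and the <|image_pad|> token are assumptions):

```python
# Worked example of the token budget above (grid and merge factor assumed):
import numpy as np

merge_length = 4  # merge_size ** 2 for the 2x2 patch merger
image_grid_thw = np.array([1, 16, 16])  # (t, h, w) patch counts, illustrative
image_seqlen = int(image_grid_thw.prod()) // merge_length  # 256 // 4 = 64
prompt_piece = "<|vision_start|>" + "<|image_pad|>" * image_seqlen + "<|vision_end|>"
```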
@@ -640,19 +661,22 @@ def process_messages(
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1

if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim

image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1

if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim

image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1

for message in messages:
content = message["content"]
