[Feat] Add support for evaluation of InternVideo2-Chat && Fix evaluation for mvbench (#280)

* [add] add internvideo2 support && change mvbench to video branch

* [add] answer_prompt of internvideo2

* [add] change video type of internvideo2

* [fix] update template of mvbench

* [reformat]

* [fix] generate_until_multi_round

* [Feat] videochat2 support

---------

Co-authored-by: heyinan <[email protected]>
yinanhe and heyinan authored Oct 2, 2024
1 parent 7c2d91c commit af395ae
Showing 26 changed files with 933 additions and 55 deletions.
3 changes: 2 additions & 1 deletion lmms_eval/api/task.py
@@ -937,7 +937,8 @@ def _download_from_youtube(path):
 if accelerator.is_main_process:
     force_download = dataset_kwargs.get("force_download", False)
     force_unzip = dataset_kwargs.get("force_unzip", False)
-    cache_path = snapshot_download(repo_id=self.DATASET_PATH, repo_type="dataset", force_download=force_download, etag_timeout=60)
+    revision = dataset_kwargs.get("revision", "main")
+    cache_path = snapshot_download(repo_id=self.DATASET_PATH, revision=revision, repo_type="dataset", force_download=force_download, etag_timeout=60)
     zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
     tar_files = glob(os.path.join(cache_path, "**/*.tar*"), recursive=True)
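
For context, this change lets a task pin a specific dataset snapshot through its `dataset_kwargs`; `revision` falls back to "main" when absent. A minimal sketch of the resulting call, with a hypothetical repo id and revision (neither is taken from this commit):

from huggingface_hub import snapshot_download

dataset_kwargs = {"revision": "refs/convert/parquet"}  # hypothetical task config
cache_path = snapshot_download(
    repo_id="SomeOrg/SomeVideoBenchmark",  # illustrative dataset repo
    revision=dataset_kwargs.get("revision", "main"),
    repo_type="dataset",
    force_download=dataset_kwargs.get("force_download", False),
    etag_timeout=60,
)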

2 changes: 2 additions & 0 deletions lmms_eval/models/__init__.py
@@ -44,8 +44,10 @@
"video_llava": "VideoLLaVA",
"vila": "VILA",
"xcomposer2_4KHD": "XComposer2_4KHD",
"internvideo2": "InternVideo2",
"xcomposer2d5": "XComposer2D5",
"oryx": "Oryx",
"videochat2": "VideoChat2",
}
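
For reference, each key here is the model name passed on the command line, and each value is the class it resolves to. Assuming the mapping above is the module's `AVAILABLE_MODELS` dict (the hunk does not show its name), a rough sketch of the lookup:

import importlib

def resolve_model(name: str):
    # e.g. "internvideo2" -> lmms_eval.models.internvideo2.InternVideo2
    class_name = AVAILABLE_MODELS[name]
    module = importlib.import_module(f"lmms_eval.models.{name}")
    return getattr(module, class_name)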


366 changes: 366 additions & 0 deletions lmms_eval/models/internvideo2.py
@@ -0,0 +1,366 @@
import logging
import os
from typing import List, Tuple

import decord
import numpy as np
import torch
import torchvision.transforms as T
from accelerate import Accelerator, DistributedType
from decord import VideoReader, cpu

decord.bridge.set_bridge("torch")
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

eval_logger = logging.getLogger("eval_logger")


from datetime import timedelta

from accelerate.state import AcceleratorState
from accelerate.utils import InitProcessGroupKwargs

DEFAULT_GEN_KWARGS = dict(
    num_beams=1,
    max_new_tokens=1024,
    do_sample=False,
)

# def get_index(num_frames, num_segments):
#     seg_size = float(num_frames - 1) / num_segments
#     start = int(seg_size / 2)
#     offsets = np.array([
#         start + int(np.round(seg_size * idx)) for idx in range(num_segments)
#     ])
#     return offsets


def get_index(max_frame, num_segments, fps, first_idx=0, bound=None):
    if bound:
        start, end = bound[0], bound[1]
        if start is None:
            start, end = -100000, 100000
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
    return frame_indices


def load_image(image_path, resolution=224, hd_num=6):
    image = Image.open(image_path).convert("RGB")
    image_tensor = T.PILToTensor()(image).unsqueeze(0)
    image_tensor = HD_transform_no_padding(image_tensor.float(), image_size=resolution, hd_num=hd_num)
    T_, C, H, W = image_tensor.shape

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = T.Compose([T.Lambda(lambda x: x.float().div(255.0)), T.Normalize(mean, std)])
    image_tensor = transform(image_tensor).cuda()

    sub_img = image_tensor.reshape(1, T_, 3, H // resolution, resolution, W // resolution, resolution).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous()

    glb_img = F.interpolate(image_tensor.float(), size=(resolution, resolution), mode="bicubic", align_corners=False).to(sub_img.dtype).unsqueeze(0)

    image_tensor = torch.cat([sub_img, glb_img])  # .unsqueeze(0)
    return image_tensor


def load_video(video_path, num_segments=16, return_msg=False, resolution=224, hd_num=6, padding=False):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr) - 1

    frame_indices = get_index(max_frame=num_frames, num_segments=num_segments, fps=float(vr.get_avg_fps()), first_idx=0, bound=None)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = T.Compose([T.Lambda(lambda x: x.float().div(255.0)), T.Normalize(mean, std)])

    frames = vr.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)

    if padding:
        frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num)
    else:
        frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num)

    frames = transform(frames)
    T_, C, H, W = frames.shape

    sub_img = frames.reshape(1, T_, 3, H // resolution, resolution, W // resolution, resolution).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous()

    glb_img = F.interpolate(frames.float(), size=(resolution, resolution), mode="bicubic", align_corners=False).to(sub_img.dtype).unsqueeze(0)

    frames = torch.cat([sub_img, glb_img]).unsqueeze(0)

    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # a leading and trailing " " should be added when this message is inserted into the prompt
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return frames, msg
    else:
        return frames
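

# Usage sketch (illustrative, not part of this commit): with the default fixed
# (2, 1) HD ratio, each clip yields two 224x224 sub-tiles plus one global view
# per sampled frame:
#
#   >>> frames, msg = load_video("example.mp4", num_segments=16, return_msg=True)
#   >>> frames.shape  # (batch, 2 sub-tiles + 1 global, T, C, H, W)
#   torch.Size([1, 3, 16, 3, 224, 224])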


def HD_transform_padding(frames, image_size=224, hd_num=6):
    def _padding_224(frames):
        _, _, H, W = frames.shape
        tar = int(np.ceil(H / 224) * 224)
        top_padding = (tar - H) // 2
        bottom_padding = tar - H - top_padding
        left_padding = 0
        right_padding = 0

        padded_frames = F.pad(frames, pad=[left_padding, right_padding, top_padding, bottom_padding], mode="constant", value=255)
        return padded_frames

    _, _, H, W = frames.shape
    trans = False
    if W < H:
        frames = frames.flip(-2, -1)
        trans = True
        width, height = H, W
    else:
        width, height = W, H

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * image_size)
    new_h = int(new_w / ratio)

    resized_frames = F.interpolate(frames, size=(new_h, new_w), mode="bicubic", align_corners=False)
    padded_frames = _padding_224(resized_frames)

    if trans:
        padded_frames = padded_frames.flip(-2, -1)

    return padded_frames


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2, 1)):
    min_num = 1
    max_num = hd_num
    _, _, orig_height, orig_width = frames.shape
    aspect_ratio = orig_width / orig_height

    # calculate the existing video aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    if fix_ratio:
        target_aspect_ratio = fix_ratio
    else:
        target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the frames
    resized_frame = F.interpolate(frames, size=(target_height, target_width), mode="bicubic", align_corners=False)
    return resized_frame
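

# Shape check (illustrative, not part of this commit): under the default
# fix_ratio=(2, 1), any input is resized to a 2 x 1 grid of image_size tiles:
#
#   >>> HD_transform_no_padding(torch.rand(16, 3, 1080, 1920), image_size=224, hd_num=6).shape
#   torch.Size([16, 3, 224, 448])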


@register_model("InternVideo2")
class InternVideo2(lmms):
def __init__(
self,
pretrained: str = "OpenGVLab/InternVideo2_chat_8B_HD",
modality: str = "video",
device: str = "cuda:0",
device_map: str = "cuda:0",
batch_size: str = "1",
num_segments: str = "8",
hd_num: str = "6",
**kwargs,
):
super().__init__()
self.path = pretrained
self.instruction = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons.\n"

self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, use_fast=False)
self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
batch_size = int(batch_size)
self.num_segments = int(num_segments)
self.hd_num = int(hd_num)
assert batch_size == 1, f"Batch size should be 1 for InternVideo2, but got {batch_size}."
self.batch_size_per_gpu = batch_size
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
elif accelerator.num_processes == 1 and device_map == "auto":
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
# Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
# I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")

if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == "auto":
eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
self._rank = 0
self._word_size = 1
else:
eval_logger.info(f"Using single device: {self._device}")
self.model.to(self._device)
self._rank = 0
self._world_size = 1

self.modality = modality

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            if "until" in gen_kwargs:
                gen_kwargs.pop("until")
            for k, v in DEFAULT_GEN_KWARGS.items():
                if k not in gen_kwargs:
                    gen_kwargs[k] = v

            # drop any generation kwargs the model's chat API does not understand
            pop_keys = []
            for k, v in gen_kwargs.items():
                if k not in DEFAULT_GEN_KWARGS:
                    pop_keys.append(k)

            for k in pop_keys:
                gen_kwargs.pop(k)

            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)
            if self.modality == "image":
                image_path = visuals[0]
                pixel_values = load_image(image_path, resolution=224, hd_num=self.hd_num)
                pixel_values = pixel_values.to(torch.bfloat16).cuda()
                question = contexts
                response, history = self.model.chat(self.tokenizer, msg="", user_prompt=question, media_type="image", media_tensor=pixel_values, instruction=None, chat_history=[], return_history=True, **gen_kwargs)
            elif self.modality == "video":
                assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos. [META-INFO]{visuals}"
                video_path = visuals[0]
                if "mvbench" in task:
                    answer_prompt = "Best Option:("
                else:
                    answer_prompt = None
                pixel_values = load_video(video_path, num_segments=self.num_segments, return_msg=False, resolution=224, hd_num=self.hd_num)
                pixel_values = pixel_values.to(torch.bfloat16).cuda()
                question = self.instruction + contexts
                response, history = self.model.chat(
                    self.tokenizer,
                    msg="",
                    user_prompt=question,
                    media_type="video",
                    media_tensor=pixel_values,
                    instruction=self.instruction,
                    chat_history=[],
                    return_history=True,
                    generation_config=gen_kwargs,
                    answer_prompt=answer_prompt,
                    debug_conv=False,
                )
            res.append(response)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Loglikelihood is not implemented yet for InternVideo2.")

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for InternVideo2")
