Showing 9 changed files with 1,037 additions and 3 deletions.
@@ -0,0 +1,13 @@

## Run benchmark

### Benchmark sglang

```
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
```

### Benchmark hf

```
python benchmark/mmmu/bench_other.py --model-path Qwen/Qwen2-VL-7B-Instruct
```
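
Both scripts below write the per-sample predictions to a JSON file via `save_json` and then score it with `eval_result`. A minimal sketch for inspecting that file after a run; the exact filename is built from `--model-path` inside the scripts, so the path here is only illustrative:

```
import json

# Illustrative path only: the scripts build it as f"{model_path}_val.json"
# (bench_sglang.py also appends a random suffix).
output_path = "Qwen2-VL-7B-Instruct_val.json"

with open(output_path) as f:
    predictions = json.load(f)  # maps MMMU sample id -> predicted answer

print(f"{len(predictions)} samples evaluated")
for sample_id, answer in list(predictions.items())[:5]:
    print(sample_id, "->", answer)
```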
@@ -0,0 +1,127 @@
""" | ||
Bench the huggingface vLM with benchmark MMMU | ||
Usage: | ||
python benchmark/mmmu/bench_other.py --model-path Qwen/Qwen2-VL-7B-Instruct | ||
The eval output will be logged | ||
""" | ||
|
||
import argparse | ||
import random | ||
|
||
import torch | ||
from tqdm import tqdm | ||
from transformers import AutoModelForImageTextToText, AutoProcessor | ||
|
||
from benchmark.mmmu.bench_sglang import EvalArgs, prepare_samples | ||
from benchmark.mmmu.data_utils import save_json | ||
from benchmark.mmmu.eval_utils import eval_result, parse_multi_choice_response | ||
|
||
|
||
@torch.no_grad() | ||
def eval_mmmu(args): | ||
eval_args = EvalArgs.from_cli_args(args) | ||
|
||
model = AutoModelForImageTextToText.from_pretrained( | ||
args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True | ||
) | ||
model = model.eval().cuda() | ||
processor = AutoProcessor.from_pretrained( | ||
args.model_path, torch_dtype="auto", device_map="auto" | ||
) | ||
|
||
samples = prepare_samples(eval_args) | ||
out_samples = dict() | ||
|
||
max_new_tokens = 128 | ||
temperature = 0.0 | ||
|
||
sampling_params = { | ||
"do_sample": False, | ||
"temperature": temperature, | ||
"max_new_tokens": max_new_tokens, | ||
} | ||
|
||
answer_dict = {} | ||
for sample in tqdm(samples): | ||
prompt = sample["final_input_prompt"] | ||
image = sample["image"] | ||
prefix = prompt.split("<")[0] | ||
suffix = prompt.split(">")[1] | ||
if image is not None: | ||
messages = [ | ||
{ | ||
"role": "user", | ||
"content": [ | ||
{"type": "text", "text": prefix}, | ||
{ | ||
"type": "image", | ||
"image": image, | ||
}, | ||
{"type": "text", "text": suffix}, | ||
], | ||
} | ||
] | ||
text = processor.apply_chat_template( | ||
messages, tokenize=False, add_generation_prompt=True | ||
) | ||
inputs = processor( | ||
text=[text], | ||
images=[image], | ||
padding=True, | ||
return_tensors="pt", | ||
) | ||
|
||
inputs = inputs.to("cuda") | ||
|
||
generated_ids = model.generate(**inputs, **sampling_params) | ||
|
||
response = processor.decode( | ||
generated_ids[0], | ||
skip_special_tokens=True, | ||
clean_up_tokenization_spaces=False, | ||
)[len(text) :] | ||
else: # multiple images actually | ||
if sample["question_type"] == "multiple-choice": | ||
all_choices = sample["all_choices"] | ||
response = random.choice(all_choices) | ||
|
||
else: | ||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" | ||
|
||
if sample["question_type"] == "multiple-choice": | ||
pred_ans = parse_multi_choice_response( | ||
response, sample["all_choices"], sample["index2ans"] | ||
) | ||
else: # open question | ||
pred_ans = response | ||
out_samples[sample["id"]] = pred_ans | ||
|
||
# set ground truth answer | ||
answer_dict[sample["id"]] = { | ||
"question_type": sample["question_type"], | ||
"ground_truth": ( | ||
sample["correct_choice"] | ||
if "correct_choice" in samples | ||
else sample["answer"] | ||
), | ||
} | ||
|
||
args.output_path = f"{args.model_path}_val.json" | ||
save_json(args.output_path, out_samples) | ||
eval_result(output_path=args.output_path, answer_dict=answer_dict) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--model-path", | ||
type=str, | ||
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", | ||
required=True, | ||
) | ||
EvalArgs.add_cli_args(parser) | ||
args = parser.parse_args() | ||
|
||
eval_mmmu(args) |
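
The `prefix`/`suffix` split in `bench_other.py` assumes the MMMU prompt carries a single image placeholder such as `<image 1>`; the placeholder text in this sketch is an assumption, not taken from the diff:

```
# Hypothetical MMMU-style prompt with one image placeholder.
prompt = "Consider the circuit shown. <image 1> What is the output voltage?\nA. 1V\nB. 2V"

prefix = prompt.split("<")[0]  # "Consider the circuit shown. "
suffix = prompt.split(">")[1]  # " What is the output voltage?\nA. 1V\nB. 2V"
print(prefix, suffix, sep="|")
```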
@@ -0,0 +1,114 @@
""" | ||
Bench the sglang-hosted vLM with benchmark MMMU | ||
Usage: | ||
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl | ||
The eval output will be logged | ||
""" | ||
|
||
import argparse | ||
import dataclasses | ||
import json | ||
import random | ||
import re | ||
from io import BytesIO | ||
|
||
from tqdm import tqdm | ||
|
||
from benchmark.mmmu.data_utils import save_json | ||
from benchmark.mmmu.eval_utils import ( | ||
EvalArgs, | ||
eval_result, | ||
parse_multi_choice_response, | ||
prepare_samples, | ||
) | ||
from sglang import Engine | ||
from sglang.srt.conversation import chat_templates | ||
from sglang.srt.server_args import ServerArgs | ||
|
||
|
||
def eval_mmmu(args): | ||
server_args = ServerArgs.from_cli_args(args) | ||
eval_args = EvalArgs.from_cli_args(args) | ||
|
||
if server_args.chat_template is None: | ||
raise ValueError("Chat template must be provided for this benchmark") | ||
|
||
samples = prepare_samples(eval_args) | ||
|
||
backend = Engine(**dataclasses.asdict(server_args)) | ||
|
||
out_samples = dict() | ||
|
||
extra_request_body = {} | ||
if eval_args.extra_request_body: | ||
extra_request_body = json.loads(args.extra_request_body) | ||
|
||
max_new_tokens = 128 | ||
temperature = 0.0 | ||
|
||
sampling_params = { | ||
"temperature": temperature, | ||
"max_new_tokens": max_new_tokens, | ||
**extra_request_body, | ||
} | ||
|
||
conv = chat_templates[server_args.chat_template].copy() | ||
image_token = conv.image_token | ||
answer_dict = {} | ||
for sample in tqdm(samples): | ||
prompt = sample["final_input_prompt"] | ||
image = sample["image"] | ||
bytes_io = BytesIO() | ||
image.save(bytes_io, format="PNG") | ||
png_bytes = bytes_io.getvalue() | ||
|
||
prompt = re.sub(r"<[^>]*>", image_token, prompt) | ||
|
||
if image is not None: | ||
gen_out = backend.generate( | ||
prompt=prompt, image_data=png_bytes, sampling_params=sampling_params | ||
)["text"] | ||
|
||
response = gen_out | ||
# print("response: ", response) | ||
else: # multiple images actually | ||
if sample["question_type"] == "multiple-choice": | ||
all_choices = sample["all_choices"] | ||
response = random.choice(all_choices) | ||
|
||
else: | ||
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" | ||
|
||
if sample["question_type"] == "multiple-choice": | ||
pred_ans = parse_multi_choice_response( | ||
response, sample["all_choices"], sample["index2ans"] | ||
) | ||
else: # open question | ||
pred_ans = response | ||
out_samples[sample["id"]] = pred_ans | ||
|
||
# set ground truth answer | ||
answer_dict[sample["id"]] = { | ||
"question_type": sample["question_type"], | ||
"ground_truth": ( | ||
sample["correct_choice"] | ||
if "correct_choice" in samples | ||
else sample["answer"] | ||
), | ||
} | ||
|
||
random_part = str(random.randint(0, 999999)) | ||
args.output_path = f"{args.model_path}_val_{random_part}.json" | ||
save_json(args.output_path, out_samples) | ||
eval_result(output_path=args.output_path, answer_dict=answer_dict) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
ServerArgs.add_cli_args(parser) | ||
EvalArgs.add_cli_args(parser) | ||
args = parser.parse_args() | ||
|
||
eval_mmmu(args) |
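
Because `bench_sglang.py` merges `extra_request_body` into `sampling_params`, individual sampling settings can be overridden per run. Assuming `EvalArgs.add_cli_args` exposes the field as `--extra-request-body` (the flag name is not shown in this diff), an invocation might look like:

```
python benchmark/mmmu/bench_sglang.py \
    --model-path Qwen/Qwen2-VL-7B-Instruct \
    --chat-template qwen2-vl \
    --extra-request-body '{"max_new_tokens": 256}'
```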