Skip to content

Commit

Permalink
Update README to include GPT4V
Browse files Browse the repository at this point in the history
Signed-off-by: Linkun Chen <[email protected]>
  • Loading branch information
Linkun Chen committed Nov 5, 2024
1 parent 13c726c commit 06c36be
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 59 deletions.
11 changes: 11 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,14 @@ You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```

## Downloading the ShareGPT4V dataset

The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts
will ignore a datapoint if the referred image is missing.
```bash
wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
mkdir coco -p
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
unzip coco/train2017.zip -d coco/
```
59 changes: 0 additions & 59 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
import uvloop
from PIL import Image
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
Expand Down Expand Up @@ -61,33 +60,6 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
model: str = args.model
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
"""Prepend and append special tokens around the question to form a prompt.
Args:
question: The input question text to wrap with special tokens
model: The name of the model being used, to determine which special
tokens to add
Returns:
The formatted prompt string with appropriate special tokens for the
model
Raises:
ValueError: If an unsupported model name is provided
"""
model = model.lower()
if "pixtral" in model:
return f"<s>[INST]{question}\n[IMG][/INST]"
raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset
Expand All @@ -107,7 +79,6 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for data in dataset:
for data in dataset:
if len(filtered_dataset) == num_requests:
break
Expand All @@ -116,25 +87,6 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
prompt = data["conversations"][0]["value"]
completion = data["conversations"][1]["value"]

multi_modal_data: Optional[MultiModalDataDict] = None
if "image" in data:
multi_modal_data = multi_modal_data or {}
image_path = data["image"]
# TODO(vllm-project/vllm/issues/9778): Support multiple images.
assert isinstance(image_path,
str), "Only support single image input"
try:
multi_modal_data["image"] = Image.open(image_path).convert(
"RGB")
except FileNotFoundError:
# Ignore datapoint where asset is missing
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)

# Only keep the first two turns of each conversation.
prompt = data["conversations"][0]["value"]
completion = data["conversations"][1]["value"]

multi_modal_data: Optional[MultiModalDataDict] = None
if "image" in data:
multi_modal_data = multi_modal_data or {}
Expand Down Expand Up @@ -167,8 +119,6 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))
expected_output_len=output_len,
multi_modal_data=multi_modal_data))

return filtered_dataset

Expand All @@ -185,9 +135,6 @@ def run_vllm(
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
Expand Down Expand Up @@ -239,9 +186,6 @@ async def run_vllm_async(
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
Expand Down Expand Up @@ -369,10 +313,7 @@ def main(args: argparse.Namespace):
]
else:
requests = sample_requests(tokenizer, args)
requests = sample_requests(tokenizer, args)

is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
if args.backend == "vllm":
Expand Down

0 comments on commit 06c36be

Please sign in to comment.