From fe8b351cd39f07ef44175a6b1c7102315ec7f415 Mon Sep 17 00:00:00 2001
From: FrankLeeeee
Date: Thu, 13 Feb 2025 14:13:31 +0000
Subject: [PATCH 1/3] added support for vlm in offline inference

---
 .../engine/offline_batch_inference_vlm.py | 67 +++++++++++++++++++
 python/sglang/srt/entrypoints/engine.py   | 13 ++++
 2 files changed, 80 insertions(+)
 create mode 100644 examples/runtime/engine/offline_batch_inference_vlm.py

diff --git a/examples/runtime/engine/offline_batch_inference_vlm.py b/examples/runtime/engine/offline_batch_inference_vlm.py
new file mode 100644
index 00000000000..cd0a62e1d29
--- /dev/null
+++ b/examples/runtime/engine/offline_batch_inference_vlm.py
@@ -0,0 +1,67 @@
+"""
+Usage:
+python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct
+"""
+
+import argparse
+import dataclasses
+
+from transformers import AutoProcessor
+
+import sglang as sgl
+from sglang.srt.openai_api.adapter import v1_chat_generate_request
+from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.server_args import ServerArgs
+
+
+def main(
+    server_args: ServerArgs,
+):
+    # Create the VLM engine.
+    vlm = sgl.Engine(**dataclasses.asdict(server_args))
+
+    # Prepare the chat messages.
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What’s in this image?"},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
+                    },
+                },
+            ],
+        }
+    ]
+    request = ChatCompletionRequest(
+        messages=messages,
+        model="Qwen/Qwen2-VL-7B-Instruct",
+        temperature=0.8,
+        top_p=0.95,
+    )
+    gen_request, _ = v1_chat_generate_request(
+        [request],
+        vlm.tokenizer_manager,
+    )
+
+    outputs = vlm.generate(
+        input_ids=gen_request.input_ids,
+        image_data=gen_request.image_data,
+        sampling_params=gen_request.sampling_params,
+    )
+
+    print("===============================")
+    print(f"Prompt: {messages[0]['content'][0]['text']}")
+    print(f"Generated text: {outputs['text']}")
+
+
+# The __main__ condition is necessary here because we use "spawn" to create subprocesses.
+# Spawn starts a fresh program every time; without the __main__ guard, sgl.Engine would keep spawning new processes in an infinite loop.
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    main(server_args)

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 942a53c37cf..30978440f4c 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -115,6 +115,9 @@ def generate(
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a URL, or a base64-encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -126,14 +129,20 @@
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
""" + modalities_list = [] + if image_data is not None: + modalities_list.append("image") + obj = GenerateReqInput( text=prompt, input_ids=input_ids, sampling_params=sampling_params, + image_data=image_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, lora_path=lora_path, + modalities=modalities_list, custom_logit_processor=custom_logit_processor, stream=stream, ) @@ -162,6 +171,9 @@ async def async_generate( sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. + image_data: Optional[Union[List[str], str]] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -177,6 +189,7 @@ async def async_generate( text=prompt, input_ids=input_ids, sampling_params=sampling_params, + image_data=image_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, From 0cdbed383c01d2e4ab471808f99ae04b1b985c1e Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Thu, 13 Feb 2025 14:40:04 +0000 Subject: [PATCH 2/3] fixed usage doc --- examples/runtime/engine/offline_batch_inference_vlm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/runtime/engine/offline_batch_inference_vlm.py b/examples/runtime/engine/offline_batch_inference_vlm.py index cd0a62e1d29..bbd32acc047 100644 --- a/examples/runtime/engine/offline_batch_inference_vlm.py +++ b/examples/runtime/engine/offline_batch_inference_vlm.py @@ -1,6 +1,6 @@ """ Usage: -python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct +python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template=qwen2-vl """ import argparse @@ -35,14 +35,14 @@ def main( ], } ] - request = ChatCompletionRequest( + chat_request = ChatCompletionRequest( messages=messages, model="Qwen/Qwen2-VL-7B-Instruct", temperature=0.8, top_p=0.95, ) gen_request, _ = v1_chat_generate_request( - [request], + [chat_request], vlm.tokenizer_manager, ) From fdc82a11b697ee05dc64a338766d937bc6a8d2bd Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Thu, 13 Feb 2025 14:56:35 +0000 Subject: [PATCH 3/3] fixed typo --- examples/runtime/engine/offline_batch_inference_vlm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/runtime/engine/offline_batch_inference_vlm.py b/examples/runtime/engine/offline_batch_inference_vlm.py index bbd32acc047..808d0fce9b7 100644 --- a/examples/runtime/engine/offline_batch_inference_vlm.py +++ b/examples/runtime/engine/offline_batch_inference_vlm.py @@ -1,6 +1,6 @@ """ Usage: -python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template=qwen2-vl +python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template=qwen2-vl """ import argparse @@ -37,7 +37,7 @@ def main( ] chat_request = ChatCompletionRequest( messages=messages, - model="Qwen/Qwen2-VL-7B-Instruct", + model=server_args.model_path, temperature=0.8, top_p=0.95, )