From 3b19e39dc59975c4240082a6b45f7c2a123b7164 Mon Sep 17 00:00:00 2001 From: nunjunj <106306814+nunjunj@users.noreply.github.com> Date: Fri, 16 Aug 2024 09:41:34 +0700 Subject: [PATCH] Chat method for offline llm (#5049) Co-authored-by: nunjunj Co-authored-by: nunjunj Co-authored-by: nunjunj Co-authored-by: Cyrus Leung Co-authored-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 1 + examples/offline_inference_chat.py | 53 +++++++++++ tests/entrypoints/llm/test_generate.py | 19 ++++ vllm/entrypoints/llm.py | 124 +++++++++++++++++++------ 4 files changed, 168 insertions(+), 29 deletions(-) create mode 100644 examples/offline_inference_chat.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6b3dbb1ccb7d8..8c0fc8e05a33e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -147,6 +147,7 @@ steps: - pip install awscli tensorizer # for llava example and tensorizer test - python3 offline_inference.py - python3 cpu_offload.py + - python3 offline_inference_chat.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 offline_inference_vision_language.py diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py new file mode 100644 index 0000000000000..c2020724c72fe --- /dev/null +++ b/examples/offline_inference_chat.py @@ -0,0 +1,53 @@ +from vllm import LLM, SamplingParams + +llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") +sampling_params = SamplingParams(temperature=0.5) + + +def print_outputs(outputs): + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print("-" * 80) + + +print("=" * 80) + +# In this script, we demonstrate how to pass input to the chat method: + +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +outputs = llm.chat(conversation, + sampling_params=sampling_params, + use_tqdm=False) +print_outputs(outputs) + +# A chat template can be optionally supplied. +# If not, the model will use its default chat template. + +# with open('template_falcon_180b.jinja', "r") as f: +# chat_template = f.read() + +# outputs = llm.chat( +# conversations, +# sampling_params=sampling_params, +# use_tqdm=False, +# chat_template=chat_template, +# ) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 57ac37f7ea8f7..c426e9b4ee899 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM): # sampling_params is None, default params should be applied outputs = llm.generate(PROMPTS, sampling_params=None) assert len(PROMPTS) == len(outputs) + + +def test_chat(): + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + outputs = llm.chat(messages) + assert len(outputs) == 1 diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 175f418a1294f..32bdb2b7d14f4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,6 +6,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + apply_chat_template, + parse_chat_messages) from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger @@ -87,7 +90,7 @@ class LLM: disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) - + Note: This class is intended to be used for offline inference. For online serving, use the :class:`~vllm.AsyncLLMEngine` class instead. @@ -138,8 +141,12 @@ def __init__( if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True - removed_vision_keys = ("image_token_id", "image_feature_size", - "image_input_shape", "image_input_type") + removed_vision_keys = ( + "image_token_id", + "image_feature_size", + "image_input_shape", + "image_input_type", + ) if any(k in kwargs for k in removed_vision_keys): raise TypeError( "There is no need to pass vision-related arguments anymore.") @@ -259,11 +266,12 @@ def generate( ) -> List[RequestOutput]: ... - @deprecate_kwargs("prompts", - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.") + @deprecate_kwargs( + "prompts", + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter instead.", + ) def generate( self, prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], @@ -286,17 +294,17 @@ def generate( Args: inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. - When it is a single value, it is applied to every prompt. - When it is a list, the list must have the same length as the + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: - A list of `RequestOutput` objects containing the + A list of ``RequestOutput`` objects containing the generated completions in the same order as the input prompts. Note: @@ -339,6 +347,62 @@ def generate( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) + def chat( + self, + messages: List[ChatCompletionMessageParam], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + add_generation_template: bool = True, + ) -> List[RequestOutput]: + """ + Generates responses for chat messages. + + Converts the messages to prompts using the tokenizer and calls + the :meth:`generate` method to generate the responses. + + Args: + messages: A list of messages to generate responses for. Each + message is a list of dictionaries with 'role' and 'content' + keys. + sampling_params: The sampling parameters for text generation. + If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + chat_template: The template to use for structuring the chat. + If not provided, the model's default chat template will be used. + add_generation_template: If True, adds a generation template + to each message. + + Returns: + A list of ``RequestOutput`` objects containing the generated + responses in the same order as the input messages. + """ + + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + + conversations, _ = parse_chat_messages(messages, model_config, + tokenizer) + + prompts = apply_chat_template( + tokenizer, + conversations, + chat_template=chat_template, + add_generation_template=add_generation_template) + + return self.generate( + prompts, + sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + @overload # LEGACY: single (prompt + optional token ids) def encode( self, @@ -413,11 +477,12 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @deprecate_kwargs("prompts", - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.") + @deprecate_kwargs( + "prompts", + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter instead.", + ) def encode( self, prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], @@ -443,7 +508,7 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: @@ -563,15 +628,15 @@ def _validate_and_add_requests( params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], - LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( @@ -579,7 +644,8 @@ def _add_request( inputs, params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) def _add_guided_processor( self, @@ -628,8 +694,8 @@ def _run_engine( in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( len(stp.token_ids) for stp in output.outputs) - out_spd = total_out_toks / pbar.format_dict[ - "elapsed"] + out_spd = (total_out_toks / + pbar.format_dict["elapsed"]) pbar.postfix = ( f"est. speed input: {in_spd:.2f} toks/s, " f"output: {out_spd:.2f} toks/s")