From e4b1d87926cdf3fe66f8e59d860f54638ae7336e Mon Sep 17 00:00:00 2001 From: nunjunj Date: Tue, 23 Jul 2024 23:04:54 +0700 Subject: [PATCH 01/12] add chat method --- .buildkite/test-pipeline.yaml | 1 + examples/offline_inference_chat.py | 95 ++++++++++++++++++++++++++ tests/entrypoints/llm/test_generate.py | 40 +++++++++++ vllm/entrypoints/llm.py | 55 ++++++++++++++- 4 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 examples/offline_inference_chat.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e7dd1fdb2e660..f29a23a1105e5 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -146,6 +146,7 @@ steps: - pip install awscli tensorizer - python3 offline_inference.py - python3 cpu_offload.py + - python3 offline_inference_chat.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py new file mode 100644 index 0000000000000..b2e10b4af7ffd --- /dev/null +++ b/examples/offline_inference_chat.py @@ -0,0 +1,95 @@ +from vllm import LLM, SamplingParams + +llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") +sampling_params = SamplingParams(temperature=0.5) + + +def print_outputs(outputs): + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print("-" * 80) + + +print("=" * 80) + +# In this script, we demonstrate four different ways to pass input to the chat method of the LLM class: + +# Conversation with a list of dictionaries +conversation = [ + { + 'role': 'system', + 'content': "You are a helpful assistant" + }, + { + 'role': 'user', + 'content': "Hello" + }, + { + 'role': 'assistant', + 'content': "Hello! How can I assist you today?" + }, + { + 'role': 'user', + 'content': "Write an essay about the importance of higher education." + }, +] +outputs = llm.chat(conversation, + sampling_params=sampling_params, + use_tqdm=False) +print_outputs(outputs) + +# Multiple conversations +conversations = [ + [ + { + 'role': 'system', + 'content': "You are a helpful assistant" + }, + { + 'role': 'user', + 'content': "What is dark matter?" + }, + ], + [ + { + 'role': 'system', + 'content': "You are a helpful assistant" + }, + { + 'role': 'user', + 'content': "How are you?" + }, + { + 'role': + 'assistant', + 'content': + "I'm an AI, so I don't have feelings, but I'm here to help you!" + }, + { + 'role': 'user', + 'content': "Tell me a joke." + }, + ], +] + +outputs = llm.chat( + conversations, + sampling_params=sampling_params, + use_tqdm=False, +) +print_outputs(outputs) + +# A chat template can be optionally supplied. +# If not, the model will use its default chat template. 
+ +# with open('template_falcon_180b.jinja', "r") as f: +# chat_template = f.read() + +# outputs = llm.chat( +# conversations, +# sampling_params=sampling_params, +# use_tqdm=False, +# chat_template=chat_template, +# ) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 57ac37f7ea8f7..6e94f684b155d 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -140,3 +140,43 @@ def test_multiple_sampling_params(llm: LLM): # sampling_params is None, default params should be applied outputs = llm.generate(PROMPTS, sampling_params=None) assert len(PROMPTS) == len(outputs) + +def test_chat(): + + llm = LLM(model=MODEL_NAME) + + prompt1 = "Explain the concept of entropy." + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + outputs = llm.chat(messages) + assert len(outputs) == 1 + + prompt2 = "Describe Bangkok in 150 words." + messages = [messages] + [[ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ]] + outputs = llm.chat(messages) + assert len(outputs) == len(messages) + + sampling_params = [ + SamplingParams(temperature=0.01, top_p=0.95), + SamplingParams(temperature=0.3, top_p=0.95), + ] + + outputs = llm.chat(messages, sampling_params=sampling_params) + assert len(outputs) == len(messages) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..63388376aef49 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload +from typing import ClassVar, Dict, List, Optional, Sequence, Union, cast, overload from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -315,6 +315,59 @@ def generate( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) + + def chat( + self, + messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + ) -> List[RequestOutput]: + """ + Generates responses for chat messages. + Converts the messages to prompts using the tokenizer and calls + the `generate` method to generate the responses. + Args: + messages: A list of messages to generate responses for. Each + message is a list of dictionaries with 'role' and 'content' + keys. + sampling_params: The sampling parameters for text generation. + If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + chat_template: The template to use for structuring the chat. + If not provided, the model's default chat template will be used. + Returns: + A list of `RequestOutput` objects containing the generated + responses in the same order as the input messages. + """ + + tokenizer = self.get_tokenizer() + + if isinstance(messages[0], dict): + # Apply chat templates for chat inputs. 
+ prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_template=True) + + elif isinstance(messages[0], list): + tokenizer = self.get_tokenizer() + prompts = [ + tokenizer.apply_chat_template(message, + tokenize=False, + add_generation_template=True) + for message in messages + ] + + return self.generate(prompts, + sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + chat_template=chat_template) @overload # LEGACY: single (prompt + optional token ids) def encode( From 189b1003770d309e55140c89284095aeba7b37c8 Mon Sep 17 00:00:00 2001 From: nunjunj Date: Tue, 23 Jul 2024 09:39:26 -0700 Subject: [PATCH 02/12] add chat_template and add_generation_prompt params --- examples/offline_inference_chat.py | 4 ++-- vllm/entrypoints/llm.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index b2e10b4af7ffd..b4d463578406b 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -14,7 +14,7 @@ def print_outputs(outputs): print("=" * 80) -# In this script, we demonstrate four different ways to pass input to the chat method of the LLM class: +# In this script, we demonstrate two ways to pass input to the chat method of the LLM class: # Conversation with a list of dictionaries conversation = [ @@ -40,7 +40,7 @@ def print_outputs(outputs): use_tqdm=False) print_outputs(outputs) -# Multiple conversations +# Multiple conversations conversations = [ [ { diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63388376aef49..aa804f636a8d9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -324,6 +324,7 @@ def chat( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, + add_generation_template: bool = True, ) -> List[RequestOutput]: """ Generates responses for chat messages. @@ -342,6 +343,7 @@ def chat( lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. + add_generation_template: If True, adds a generation template to each message. Returns: A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. @@ -352,22 +354,25 @@ def chat( if isinstance(messages[0], dict): # Apply chat templates for chat inputs. 
prompts = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_template=True) + messages, tokenize=False, + add_generation_template=add_generation_template, + chat_template=chat_template) elif isinstance(messages[0], list): tokenizer = self.get_tokenizer() prompts = [ tokenizer.apply_chat_template(message, tokenize=False, - add_generation_template=True) + add_generation_template=add_generation_template, + chat_template=chat_template) for message in messages ] return self.generate(prompts, sampling_params, use_tqdm=use_tqdm, - lora_request=lora_request, - chat_template=chat_template) + lora_request=lora_request + ) @overload # LEGACY: single (prompt + optional token ids) def encode( From 7fcc1bfd359c6bf41d54740dee577423e87d5e6e Mon Sep 17 00:00:00 2001 From: nunjunj Date: Tue, 23 Jul 2024 10:05:40 -0700 Subject: [PATCH 03/12] run format.sh --- examples/offline_inference_chat.py | 60 ++---- tests/entrypoints/llm/test_generate.py | 7 +- vllm/entrypoints/llm.py | 268 +++++++++++++++---------- 3 files changed, 181 insertions(+), 154 deletions(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index b4d463578406b..aea8dedd55798 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -14,63 +14,37 @@ def print_outputs(outputs): print("=" * 80) -# In this script, we demonstrate two ways to pass input to the chat method of the LLM class: +# In this script, we demonstrate two ways to pass input to the chat method: # Conversation with a list of dictionaries conversation = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hello! How can I assist you today?"}, { - 'role': 'system', - 'content': "You are a helpful assistant" - }, - { - 'role': 'user', - 'content': "Hello" - }, - { - 'role': 'assistant', - 'content': "Hello! How can I assist you today?" - }, - { - 'role': 'user', - 'content': "Write an essay about the importance of higher education." + "role": "user", + "content": "Write an essay about the importance of higher education.", }, ] -outputs = llm.chat(conversation, - sampling_params=sampling_params, - use_tqdm=False) +outputs = llm.chat( + conversation, sampling_params=sampling_params, use_tqdm=False +) print_outputs(outputs) -# Multiple conversations +# Multiple conversations conversations = [ [ - { - 'role': 'system', - 'content': "You are a helpful assistant" - }, - { - 'role': 'user', - 'content': "What is dark matter?" - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "What is dark matter?"}, ], [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "How are you?"}, { - 'role': 'system', - 'content': "You are a helpful assistant" - }, - { - 'role': 'user', - 'content': "How are you?" - }, - { - 'role': - 'assistant', - 'content': - "I'm an AI, so I don't have feelings, but I'm here to help you!" - }, - { - 'role': 'user', - 'content': "Tell me a joke." 
+ "role": "assistant", + "content": "I'm an AI without feelings, but I'm here to help!", }, + {"role": "user", "content": "Tell me a joke."}, ], ] diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 6e94f684b155d..d994ae9758886 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -141,6 +141,7 @@ def test_multiple_sampling_params(llm: LLM): outputs = llm.generate(PROMPTS, sampling_params=None) assert len(PROMPTS) == len(outputs) + def test_chat(): llm = LLM(model=MODEL_NAME) @@ -160,7 +161,7 @@ def test_chat(): assert len(outputs) == 1 prompt2 = "Describe Bangkok in 150 words." - messages = [messages] + [[ + multiple_messages = [messages] + [[ { "role": "system", "content": "You are a helpful assistant" @@ -170,8 +171,8 @@ def test_chat(): "content": prompt2 }, ]] - outputs = llm.chat(messages) - assert len(outputs) == len(messages) + outputs = llm.chat(multiple_messages) + assert len(outputs) == len(multiple_messages) sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index aa804f636a8d9..d58bd3bb359a3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from contextlib import contextmanager -from typing import ClassVar, Dict, List, Optional, Sequence, Union, cast, overload +from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast, + overload) from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -84,7 +85,7 @@ class LLM: disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) - + Note: This class is intended to be used for offline inference. For online serving, use the :class:`~vllm.AsyncLLMEngine` class instead. @@ -126,11 +127,16 @@ def __init__( ) -> None: if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True - removed_vision_keys = ("image_token_id", "image_feature_size", - "image_input_shape", "image_input_type") + removed_vision_keys = ( + "image_token_id", + "image_feature_size", + "image_input_shape", + "image_input_type", + ) if any(k in kwargs for k in removed_vision_keys): raise TypeError( - "There is no need to pass vision-related arguments anymore.") + "There is no need to pass vision-related arguments anymore." 
+ ) engine_args = EngineArgs( model=model, tokenizer=tokenizer, @@ -153,11 +159,13 @@ def __init__( **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.LLM_CLASS) + engine_args, usage_context=UsageContext.LLM_CLASS + ) self.request_counter = Counter() def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + self + ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer.tokenizer def set_tokenizer( @@ -171,14 +179,16 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = tokenizer else: self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( - tokenizer) + tokenizer + ) @overload # LEGACY: single (prompt + optional token ids) def generate( self, prompts: str, - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -189,8 +199,9 @@ def generate( def generate( self, prompts: List[str], - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -201,8 +212,9 @@ def generate( def generate( self, prompts: Optional[str] = None, - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, *, prompt_token_ids: List[int], use_tqdm: bool = True, @@ -214,8 +226,9 @@ def generate( def generate( self, prompts: Optional[List[str]] = None, - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, *, prompt_token_ids: List[List[int]], use_tqdm: bool = True, @@ -240,24 +253,29 @@ def generate( inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, ) -> List[RequestOutput]: ... 
- @deprecate_kwargs("prompts", - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.") + @deprecate_kwargs( + "prompts", + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " "instead.", + ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + prompts: Union[ + Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]], + ] = None, + sampling_params: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -272,13 +290,13 @@ def generate( Args: inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. - When it is a single value, it is applied to every prompt. - When it is a list, the list must have the same length as the + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: @@ -293,7 +311,8 @@ def generate( if self.llm_engine.model_config.embedding_mode: raise ValueError( "LLM.generate() is only supported for generation models " - "(XForCausalLM).") + "(XForCausalLM)." + ) if prompt_token_ids is not None: inputs = self._convert_v1_inputs( @@ -311,16 +330,18 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) - + def chat( self, messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, @@ -328,24 +349,25 @@ def chat( ) -> List[RequestOutput]: """ Generates responses for chat messages. - Converts the messages to prompts using the tokenizer and calls + Converts the messages to prompts using the tokenizer and calls the `generate` method to generate the responses. Args: - messages: A list of messages to generate responses for. Each - message is a list of dictionaries with 'role' and 'content' + messages: A list of messages to generate responses for. Each + message is a list of dictionaries with 'role' and 'content' keys. - sampling_params: The sampling parameters for text generation. - If None, we use the default sampling parameters. When it - is a single value, it is applied to every prompt. When it - is a list, the list must have the same length as the + sampling_params: The sampling parameters for text generation. 
+ If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. - add_generation_template: If True, adds a generation template to each message. + add_generation_template: If True, adds a generation template + to each message. Returns: - A list of `RequestOutput` objects containing the generated + A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ @@ -354,32 +376,38 @@ def chat( if isinstance(messages[0], dict): # Apply chat templates for chat inputs. prompts = tokenizer.apply_chat_template( - messages, tokenize=False, + messages, + tokenize=False, add_generation_template=add_generation_template, - chat_template=chat_template) - + chat_template=chat_template, + ) + elif isinstance(messages[0], list): tokenizer = self.get_tokenizer() prompts = [ - tokenizer.apply_chat_template(message, - tokenize=False, - add_generation_template=add_generation_template, - chat_template=chat_template) + tokenizer.apply_chat_template( + message, + tokenize=False, + add_generation_template=add_generation_template, + chat_template=chat_template, + ) for message in messages ] - return self.generate(prompts, - sampling_params, - use_tqdm=use_tqdm, - lora_request=lora_request - ) + return self.generate( + prompts, + sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) @overload # LEGACY: single (prompt + optional token ids) def encode( self, prompts: str, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -390,8 +418,9 @@ def encode( def encode( self, prompts: List[str], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -402,8 +431,9 @@ def encode( def encode( self, prompts: Optional[str] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, *, prompt_token_ids: List[int], use_tqdm: bool = True, @@ -415,8 +445,9 @@ def encode( def encode( self, prompts: Optional[List[str]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, *, prompt_token_ids: List[List[int]], use_tqdm: bool = True, @@ -441,24 +472,29 @@ def encode( inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, use_tqdm: bool = True, lora_request: 
Optional[Union[List[LoRARequest], LoRARequest]] = None, ) -> List[EmbeddingRequestOutput]: ... - @deprecate_kwargs("prompts", - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.") + @deprecate_kwargs( + "prompts", + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " "instead.", + ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + prompts: Union[ + Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]], + ] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -478,7 +514,7 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: @@ -536,15 +572,19 @@ def _convert_v1_inputs( if prompts is not None: num_requests = len(prompts) if prompt_token_ids is not None: - if (num_requests is not None - and num_requests != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if num_requests is not None and num_requests != len( + prompt_token_ids + ): + raise ValueError( + "The lengths of prompts and prompt_token_ids " + "must be the same." + ) num_requests = len(prompt_token_ids) if num_requests is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") + raise ValueError( + "Either prompts or prompt_token_ids must be " "provided." + ) inputs: List[PromptInputs] = [] for i in range(num_requests): @@ -562,8 +602,12 @@ def _convert_v1_inputs( def _validate_and_add_requests( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], - params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, - Sequence[PoolingParams]], + params: Union[ + SamplingParams, + Sequence[SamplingParams], + PoolingParams, + Sequence[PoolingParams], + ], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: @@ -574,29 +618,31 @@ def _validate_and_add_requests( num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: - raise ValueError("The lengths of prompts and params " - "must be the same.") - if isinstance(lora_request, - list) and len(lora_request) != num_requests: - raise ValueError("The lengths of prompts and lora_request " - "must be the same.") + raise ValueError( + "The lengths of prompts and params " "must be the same." + ) + if isinstance(lora_request, list) and len(lora_request) != num_requests: + raise ValueError( + "The lengths of prompts and lora_request " "must be the same." + ) # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): self._add_request( request_inputs, params[i] if isinstance(params, Sequence) else params, - lora_request=lora_request[i] if isinstance( - lora_request, Sequence) else lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request[i] + if isinstance(lora_request, Sequence) + else lora_request, + prompt_adapter_request=prompt_adapter_request, + ) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], - LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( @@ -604,10 +650,11 @@ def _add_request( inputs, params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) def _run_engine( - self, *, use_tqdm: bool + self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -616,8 +663,10 @@ def _run_engine( total=num_requests, desc="Processed prompts", dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} toks/s, " - f"output: {0:.2f} toks/s"), + postfix=( + f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s" + ), ) # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -634,12 +683,15 @@ def _run_engine( total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs) - out_spd = total_out_toks / pbar.format_dict[ - "elapsed"] + len(stp.token_ids) for stp in output.outputs + ) + out_spd = ( + total_out_toks / pbar.format_dict["elapsed"] + ) pbar.postfix = ( f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s") + f"output: {out_spd:.2f} toks/s" + ) pbar.update(1) if use_tqdm: pbar.close() From 23739361aca007249b0327642ae9dbdb0c76cfb1 Mon Sep 17 00:00:00 2001 From: nunjunj Date: Wed, 7 Aug 2024 07:59:57 -0700 Subject: [PATCH 04/12] apply parse_chat_messages --- tests/entrypoints/llm/test_generate.py | 4 ++-- vllm/entrypoints/llm.py | 25 +++++++++++++++++-------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index d994ae9758886..e32f676a32382 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -144,7 +144,7 @@ def test_multiple_sampling_params(llm: LLM): def test_chat(): - llm = LLM(model=MODEL_NAME) + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") prompt1 = "Explain the concept of entropy." 
messages = [ @@ -179,5 +179,5 @@ def test_chat(): SamplingParams(temperature=0.3, top_p=0.95), ] - outputs = llm.chat(messages, sampling_params=sampling_params) + outputs = llm.chat(multiple_messages, sampling_params=sampling_params) assert len(outputs) == len(messages) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e65b1048096bb..a8ef394574993 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -352,7 +352,10 @@ def generate( def chat( self, - messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + messages: Union[ + List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]] + ], sampling_params: Optional[ Union[SamplingParams, List[SamplingParams]] ] = None, @@ -386,35 +389,41 @@ def chat( """ tokenizer = self.get_tokenizer() - + model_config = self.llm_engine.get_model_config() + if isinstance(messages[0], dict): - # Apply chat templates for chat inputs. - prompts = tokenizer.apply_chat_template( + conversations, _ = parse_chat_messages( messages, + model_config, + tokenizer + ) + + prompts = tokenizer.apply_chat_template( + conversations, tokenize=False, add_generation_template=add_generation_template, chat_template=chat_template, ) elif isinstance(messages[0], list): - tokenizer = self.get_tokenizer() + prompts = [ tokenizer.apply_chat_template( - message, + parse_chat_messages(message, model_config, tokenizer)[0], tokenize=False, add_generation_template=add_generation_template, chat_template=chat_template, ) for message in messages ] - + return self.generate( prompts, sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) - + @overload # LEGACY: single (prompt + optional token ids) def encode( self, From c4ca2f3cb2048d974917ce891a366051c258bbcd Mon Sep 17 00:00:00 2001 From: nunjunj Date: Thu, 8 Aug 2024 10:29:36 -0700 Subject: [PATCH 05/12] remove support for multiple chats --- examples/offline_inference_chat.py | 27 +----------- tests/entrypoints/llm/test_generate.py | 22 ---------- vllm/entrypoints/llm.py | 60 +++++++++----------------- 3 files changed, 21 insertions(+), 88 deletions(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index aea8dedd55798..f92aa8ded36b9 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -14,9 +14,8 @@ def print_outputs(outputs): print("=" * 80) -# In this script, we demonstrate two ways to pass input to the chat method: +# In this script, we demonstrate how to pass input to the chat method: -# Conversation with a list of dictionaries conversation = [ {"role": "system", "content": "You are a helpful assistant"}, {"role": "user", "content": "Hello"}, @@ -31,30 +30,6 @@ def print_outputs(outputs): ) print_outputs(outputs) -# Multiple conversations -conversations = [ - [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "What is dark matter?"}, - ], - [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "How are you?"}, - { - "role": "assistant", - "content": "I'm an AI without feelings, but I'm here to help!", - }, - {"role": "user", "content": "Tell me a joke."}, - ], -] - -outputs = llm.chat( - conversations, - sampling_params=sampling_params, - use_tqdm=False, -) -print_outputs(outputs) - # A chat template can be optionally supplied. # If not, the model will use its default chat template. 
diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index e32f676a32382..c426e9b4ee899 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -159,25 +159,3 @@ def test_chat(): ] outputs = llm.chat(messages) assert len(outputs) == 1 - - prompt2 = "Describe Bangkok in 150 words." - multiple_messages = [messages] + [[ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, - ]] - outputs = llm.chat(multiple_messages) - assert len(outputs) == len(multiple_messages) - - sampling_params = [ - SamplingParams(temperature=0.01, top_p=0.95), - SamplingParams(temperature=0.3, top_p=0.95), - ] - - outputs = llm.chat(multiple_messages, sampling_params=sampling_params) - assert len(outputs) == len(messages) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d72ade2f24ad6..8b45fdbe77731 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast, +from typing import (ClassVar, List, Optional, Sequence, Union, cast, overload) from tqdm import tqdm @@ -7,7 +7,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, parse_chat_messages +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + parse_chat_messages, + apply_chat_template) from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt) from vllm.logger import init_logger @@ -360,13 +362,9 @@ def generate( def chat( self, - messages: Union[ - List[ChatCompletionMessageParam], - List[List[ChatCompletionMessageParam]] - ], - sampling_params: Optional[ - Union[SamplingParams, List[SamplingParams]] - ] = None, + messages: List[ChatCompletionMessageParam], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, @@ -398,33 +396,17 @@ def chat( tokenizer = self.get_tokenizer() model_config = self.llm_engine.get_model_config() - - if isinstance(messages[0], dict): - conversations, _ = parse_chat_messages( - messages, - model_config, - tokenizer - ) - - prompts = tokenizer.apply_chat_template( - conversations, - tokenize=False, - add_generation_template=add_generation_template, - chat_template=chat_template, - ) - elif isinstance(messages[0], list): + conversations, _ = parse_chat_messages(messages, model_config, + tokenizer) + + prompts = apply_chat_template( + tokenizer, + conversations, + chat_template=chat_template, + add_generation_template=add_generation_template + ) - prompts = [ - tokenizer.apply_chat_template( - parse_chat_messages(message, model_config, tokenizer)[0], - tokenize=False, - add_generation_template=add_generation_template, - chat_template=chat_template, - ) - for message in messages - ] - return self.generate( prompts, sampling_params, @@ -436,9 +418,8 @@ def chat( def encode( self, prompts: str, - pooling_params: Optional[ - Union[PoolingParams, Sequence[PoolingParams]] - ] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -672,9 +653,8 @@ def _validate_and_add_requests( 
self._add_request( request_inputs, params[i] if isinstance(params, Sequence) else params, - lora_request=lora_request[i] - if isinstance(lora_request, Sequence) - else lora_request, + lora_request=lora_request[i] if isinstance( + lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request, ) From c64cc8cb358b0f122c4991bfa02bfcc5668073ff Mon Sep 17 00:00:00 2001 From: nunjunj Date: Thu, 15 Aug 2024 21:52:31 +0700 Subject: [PATCH 06/12] fix lint --- vllm/entrypoints/llm.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2ba37be8026aa..a0025cb8544e4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,5 @@ from contextlib import contextmanager -from typing import (ClassVar, List, Optional, Sequence, Union, cast, - overload) +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -8,10 +7,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - parse_chat_messages, - apply_chat_template) -from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, - parse_and_batch_prompt) + apply_chat_template, + parse_chat_messages) +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest From 3cd737e5d0b1c7f5cb56acf59facf4c9fdb19c29 Mon Sep 17 00:00:00 2001 From: nunjunj Date: Thu, 15 Aug 2024 22:41:27 +0700 Subject: [PATCH 07/12] revert format changes --- vllm/entrypoints/llm.py | 142 ++++++++++++++++------------------------ 1 file changed, 55 insertions(+), 87 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a0025cb8544e4..1bfeed63583cf 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -149,8 +149,7 @@ def __init__( ) if any(k in kwargs for k in removed_vision_keys): raise TypeError( - "There is no need to pass vision-related arguments anymore." 
- ) + "There is no need to pass vision-related arguments anymore.") engine_args = EngineArgs( model=model, tokenizer=tokenizer, @@ -173,13 +172,11 @@ def __init__( **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.LLM_CLASS - ) + engine_args, usage_context=UsageContext.LLM_CLASS) self.request_counter = Counter() def get_tokenizer( - self - ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer.tokenizer def set_tokenizer( @@ -193,16 +190,14 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = tokenizer else: self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( - tokenizer - ) + tokenizer) @overload # LEGACY: single (prompt + optional token ids) def generate( self, prompts: str, - sampling_params: Optional[ - Union[SamplingParams, List[SamplingParams]] - ] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -213,9 +208,8 @@ def generate( def generate( self, prompts: List[str], - sampling_params: Optional[ - Union[SamplingParams, List[SamplingParams]] - ] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -226,9 +220,8 @@ def generate( def generate( self, prompts: Optional[str] = None, - sampling_params: Optional[ - Union[SamplingParams, List[SamplingParams]] - ] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, *, prompt_token_ids: List[int], use_tqdm: bool = True, @@ -240,9 +233,8 @@ def generate( def generate( self, prompts: Optional[List[str]] = None, - sampling_params: Optional[ - Union[SamplingParams, List[SamplingParams]] - ] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, *, prompt_token_ids: List[List[int]], use_tqdm: bool = True, @@ -267,9 +259,8 @@ def generate( inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, - sampling_params: Optional[ - Union[SamplingParams, Sequence[SamplingParams]] - ] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, ) -> List[RequestOutput]: @@ -279,17 +270,15 @@ def generate( "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " "instead.", + additional_message="Please use the 'inputs' parameter " + "instead.", ) def generate( self, - prompts: Union[ - Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]], - ] = None, - sampling_params: Optional[ - Union[SamplingParams, Sequence[SamplingParams]] - ] = None, + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]], ] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -403,8 +392,7 @@ def chat( tokenizer, conversations, chat_template=chat_template, - 
add_generation_template=add_generation_template - ) + add_generation_template=add_generation_template) return self.generate( prompts, @@ -429,9 +417,8 @@ def encode( def encode( self, prompts: List[str], - pooling_params: Optional[ - Union[PoolingParams, Sequence[PoolingParams]] - ] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -442,9 +429,8 @@ def encode( def encode( self, prompts: Optional[str] = None, - pooling_params: Optional[ - Union[PoolingParams, Sequence[PoolingParams]] - ] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, *, prompt_token_ids: List[int], use_tqdm: bool = True, @@ -456,9 +442,8 @@ def encode( def encode( self, prompts: Optional[List[str]] = None, - pooling_params: Optional[ - Union[PoolingParams, Sequence[PoolingParams]] - ] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, *, prompt_token_ids: List[List[int]], use_tqdm: bool = True, @@ -483,9 +468,8 @@ def encode( inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, - pooling_params: Optional[ - Union[PoolingParams, Sequence[PoolingParams]] - ] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, ) -> List[EmbeddingRequestOutput]: @@ -495,17 +479,15 @@ def encode( "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " "instead.", + additional_message="Please use the 'inputs' parameter " + "instead.", ) def encode( self, - prompts: Union[ - Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]], - ] = None, - pooling_params: Optional[ - Union[PoolingParams, Sequence[PoolingParams]] - ] = None, + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]], ] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -584,18 +566,14 @@ def _convert_v1_inputs( num_requests = len(prompts) if prompt_token_ids is not None: if num_requests is not None and num_requests != len( - prompt_token_ids - ): - raise ValueError( - "The lengths of prompts and prompt_token_ids " - "must be the same." - ) + prompt_token_ids): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") num_requests = len(prompt_token_ids) if num_requests is None: - raise ValueError( - "Either prompts or prompt_token_ids must be " "provided." 
- ) + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") inputs: List[PromptInputs] = [] for i in range(num_requests): @@ -613,12 +591,8 @@ def _convert_v1_inputs( def _validate_and_add_requests( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], - params: Union[ - SamplingParams, - Sequence[SamplingParams], - PoolingParams, - Sequence[PoolingParams], - ], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams], ], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, @@ -630,13 +604,12 @@ def _validate_and_add_requests( num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: - raise ValueError( - "The lengths of prompts and params " "must be the same." - ) - if isinstance(lora_request, list) and len(lora_request) != num_requests: - raise ValueError( - "The lengths of prompts and lora_request " "must be the same." - ) + raise ValueError("The lengths of prompts and params " + "must be the same.") + if isinstance(lora_request, + list) and len(lora_request) != num_requests: + raise ValueError("The lengths of prompts and lora_request " + "must be the same.") if isinstance(params, list): params = [ @@ -692,7 +665,7 @@ def _add_guided_processor( return params def _run_engine( - self, *, use_tqdm: bool + self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -701,10 +674,8 @@ def _run_engine( total=num_requests, desc="Processed prompts", dynamic_ncols=True, - postfix=( - f"est. speed input: {0:.2f} toks/s, " - f"output: {0:.2f} toks/s" - ), + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), ) # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -721,15 +692,12 @@ def _run_engine( total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs - ) - out_spd = ( - total_out_toks / pbar.format_dict["elapsed"] - ) + len(stp.token_ids) for stp in output.outputs) + out_spd = (total_out_toks / + pbar.format_dict["elapsed"]) pbar.postfix = ( f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s" - ) + f"output: {out_spd:.2f} toks/s") pbar.update(1) if use_tqdm: pbar.close() From 9cead6b1db5c56d5efe1129caefcc0a10368a47b Mon Sep 17 00:00:00 2001 From: nunjunj Date: Thu, 15 Aug 2024 22:49:28 +0700 Subject: [PATCH 08/12] fix lint --- examples/offline_inference_chat.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index f92aa8ded36b9..c2020724c72fe 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -17,17 +17,26 @@ def print_outputs(outputs): # In this script, we demonstrate how to pass input to the chat method: conversation = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hello! How can I assist you today?"}, + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" 
+ }, { "role": "user", "content": "Write an essay about the importance of higher education.", }, ] -outputs = llm.chat( - conversation, sampling_params=sampling_params, use_tqdm=False -) +outputs = llm.chat(conversation, + sampling_params=sampling_params, + use_tqdm=False) print_outputs(outputs) # A chat template can be optionally supplied. From 20b38d9f7d2fcce82157938c1fb5fe845d785cae Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 16 Aug 2024 00:07:06 +0800 Subject: [PATCH 09/12] Reduce diffs --- vllm/entrypoints/llm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1bfeed63583cf..bf84bd74de519 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -276,7 +276,7 @@ def generate( def generate( self, prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]], ] = None, + Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, @@ -485,7 +485,7 @@ def encode( def encode( self, prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]], ] = None, + Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, @@ -565,8 +565,8 @@ def _convert_v1_inputs( if prompts is not None: num_requests = len(prompts) if prompt_token_ids is not None: - if num_requests is not None and num_requests != len( - prompt_token_ids): + if (num_requests is not None + and num_requests != len(prompt_token_ids)): raise ValueError("The lengths of prompts and prompt_token_ids " "must be the same.") @@ -592,7 +592,7 @@ def _validate_and_add_requests( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, - Sequence[PoolingParams], ], + Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, From 415193dc04c351041c130f1dfd3573a30fa2354d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 16 Aug 2024 00:07:52 +0800 Subject: [PATCH 10/12] Clean --- vllm/entrypoints/llm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index bf84bd74de519..097e547bbc15f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -479,8 +479,7 @@ def encode( "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.", + additional_message="Please use the 'inputs' parameter instead.", ) def encode( self, From 7eb64e2dede2bfd6720d86eede96b6bb97bc1c9b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 16 Aug 2024 00:27:59 +0800 Subject: [PATCH 11/12] Clean 2 --- vllm/entrypoints/llm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 097e547bbc15f..219cadf5a6db1 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -270,8 +270,7 @@ def generate( "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.", + additional_message="Please use the 'inputs' 
parameter instead.", ) def generate( self, From 64993b6cbeef6965fb051cbf6e3bce624445e04f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 16 Aug 2024 00:13:13 +0000 Subject: [PATCH 12/12] Fix docs --- vllm/entrypoints/llm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 219cadf5a6db1..32bdb2b7d14f4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -304,7 +304,7 @@ def generate( generation, if any. Returns: - A list of `RequestOutput` objects containing the + A list of ``RequestOutput`` objects containing the generated completions in the same order as the input prompts. Note: @@ -359,8 +359,10 @@ def chat( ) -> List[RequestOutput]: """ Generates responses for chat messages. + Converts the messages to prompts using the tokenizer and calls - the `generate` method to generate the responses. + the :meth:`generate` method to generate the responses. + Args: messages: A list of messages to generate responses for. Each message is a list of dictionaries with 'role' and 'content' @@ -376,8 +378,9 @@ def chat( If not provided, the model's default chat template will be used. add_generation_template: If True, adds a generation template to each message. + Returns: - A list of `RequestOutput` objects containing the generated + A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. """