From c49532cb0822bf05118462ae6e21beb6f9ea533c Mon Sep 17 00:00:00 2001 From: nunjunj Date: Tue, 23 Jul 2024 10:05:40 -0700 Subject: [PATCH] run format.sh --- examples/offline_inference_chat.py | 60 ++---- tests/entrypoints/llm/test_generate.py | 7 +- vllm/entrypoints/llm.py | 284 +++++++++++++++---------- 3 files changed, 195 insertions(+), 156 deletions(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index b4d463578406b..aea8dedd55798 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -14,63 +14,37 @@ def print_outputs(outputs): print("=" * 80) -# In this script, we demonstrate two ways to pass input to the chat method of the LLM class: +# In this script, we demonstrate two ways to pass input to the chat method: # Conversation with a list of dictionaries conversation = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hello! How can I assist you today?"}, { - 'role': 'system', - 'content': "You are a helpful assistant" - }, - { - 'role': 'user', - 'content': "Hello" - }, - { - 'role': 'assistant', - 'content': "Hello! How can I assist you today?" - }, - { - 'role': 'user', - 'content': "Write an essay about the importance of higher education." + "role": "user", + "content": "Write an essay about the importance of higher education.", }, ] -outputs = llm.chat(conversation, - sampling_params=sampling_params, - use_tqdm=False) +outputs = llm.chat( + conversation, sampling_params=sampling_params, use_tqdm=False +) print_outputs(outputs) -# Multiple conversations +# Multiple conversations conversations = [ [ - { - 'role': 'system', - 'content': "You are a helpful assistant" - }, - { - 'role': 'user', - 'content': "What is dark matter?" - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "What is dark matter?"}, ], [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "How are you?"}, { - 'role': 'system', - 'content': "You are a helpful assistant" - }, - { - 'role': 'user', - 'content': "How are you?" - }, - { - 'role': - 'assistant', - 'content': - "I'm an AI, so I don't have feelings, but I'm here to help you!" - }, - { - 'role': 'user', - 'content': "Tell me a joke." + "role": "assistant", + "content": "I'm an AI without feelings, but I'm here to help!", }, + {"role": "user", "content": "Tell me a joke."}, ], ] diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 6e94f684b155d..d994ae9758886 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -141,6 +141,7 @@ def test_multiple_sampling_params(llm: LLM): outputs = llm.generate(PROMPTS, sampling_params=None) assert len(PROMPTS) == len(outputs) + def test_chat(): llm = LLM(model=MODEL_NAME) @@ -160,7 +161,7 @@ def test_chat(): assert len(outputs) == 1 prompt2 = "Describe Bangkok in 150 words." 
- messages = [messages] + [[ + multiple_messages = [messages] + [[ { "role": "system", "content": "You are a helpful assistant" @@ -170,8 +171,8 @@ def test_chat(): "content": prompt2 }, ]] - outputs = llm.chat(messages) - assert len(outputs) == len(messages) + outputs = llm.chat(multiple_messages) + assert len(outputs) == len(multiple_messages) sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index aa804f636a8d9..ef8ba0ea70f45 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,13 +1,26 @@ from contextlib import contextmanager -from typing import ClassVar, Dict, List, Optional, Sequence, Union, cast, overload +from typing import ( + ClassVar, + Dict, + List, + Optional, + Sequence, + Union, + cast, + overload, +) from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, - parse_and_batch_prompt) +from vllm.inputs import ( + PromptInputs, + TextPrompt, + TokensPrompt, + parse_and_batch_prompt, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -84,7 +97,7 @@ class LLM: disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) - + Note: This class is intended to be used for offline inference. For online serving, use the :class:`~vllm.AsyncLLMEngine` class instead. @@ -126,11 +139,16 @@ def __init__( ) -> None: if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True - removed_vision_keys = ("image_token_id", "image_feature_size", - "image_input_shape", "image_input_type") + removed_vision_keys = ( + "image_token_id", + "image_feature_size", + "image_input_shape", + "image_input_type", + ) if any(k in kwargs for k in removed_vision_keys): raise TypeError( - "There is no need to pass vision-related arguments anymore.") + "There is no need to pass vision-related arguments anymore." 
+ ) engine_args = EngineArgs( model=model, tokenizer=tokenizer, @@ -153,11 +171,13 @@ def __init__( **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.LLM_CLASS) + engine_args, usage_context=UsageContext.LLM_CLASS + ) self.request_counter = Counter() def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + self + ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer.tokenizer def set_tokenizer( @@ -171,14 +191,16 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = tokenizer else: self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( - tokenizer) + tokenizer + ) @overload # LEGACY: single (prompt + optional token ids) def generate( self, prompts: str, - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -189,8 +211,9 @@ def generate( def generate( self, prompts: List[str], - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -201,8 +224,9 @@ def generate( def generate( self, prompts: Optional[str] = None, - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, *, prompt_token_ids: List[int], use_tqdm: bool = True, @@ -214,8 +238,9 @@ def generate( def generate( self, prompts: Optional[List[str]] = None, - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, *, prompt_token_ids: List[List[int]], use_tqdm: bool = True, @@ -240,24 +265,29 @@ def generate( inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, ) -> List[RequestOutput]: ... 
- @deprecate_kwargs("prompts", - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.") + @deprecate_kwargs( + "prompts", + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " "instead.", + ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + prompts: Union[ + Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]], + ] = None, + sampling_params: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -272,13 +302,13 @@ def generate( Args: inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. - When it is a single value, it is applied to every prompt. - When it is a list, the list must have the same length as the + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: @@ -293,7 +323,8 @@ def generate( if self.llm_engine.model_config.embedding_mode: raise ValueError( "LLM.generate() is only supported for generation models " - "(XForCausalLM).") + "(XForCausalLM)." + ) if prompt_token_ids is not None: inputs = self._convert_v1_inputs( @@ -311,16 +342,18 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) - + def chat( self, messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, List[SamplingParams]] + ] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, @@ -328,24 +361,25 @@ def chat( ) -> List[RequestOutput]: """ Generates responses for chat messages. - Converts the messages to prompts using the tokenizer and calls + Converts the messages to prompts using the tokenizer and calls the `generate` method to generate the responses. Args: - messages: A list of messages to generate responses for. Each - message is a list of dictionaries with 'role' and 'content' + messages: A list of messages to generate responses for. Each + message is a list of dictionaries with 'role' and 'content' keys. - sampling_params: The sampling parameters for text generation. - If None, we use the default sampling parameters. When it - is a single value, it is applied to every prompt. When it - is a list, the list must have the same length as the + sampling_params: The sampling parameters for text generation. 
+ If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. - add_generation_template: If True, adds a generation template to each message. + add_generation_template: If True, adds a generation template + to each message. Returns: - A list of `RequestOutput` objects containing the generated + A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ @@ -354,32 +388,38 @@ def chat( if isinstance(messages[0], dict): # Apply chat templates for chat inputs. prompts = tokenizer.apply_chat_template( - messages, tokenize=False, + messages, + tokenize=False, add_generation_template=add_generation_template, - chat_template=chat_template) - + chat_template=chat_template, + ) + elif isinstance(messages[0], list): tokenizer = self.get_tokenizer() prompts = [ - tokenizer.apply_chat_template(message, - tokenize=False, - add_generation_template=add_generation_template, - chat_template=chat_template) + tokenizer.apply_chat_template( + message, + tokenize=False, + add_generation_template=add_generation_template, + chat_template=chat_template, + ) for message in messages ] - return self.generate(prompts, - sampling_params, - use_tqdm=use_tqdm, - lora_request=lora_request - ) + return self.generate( + prompts, + sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) @overload # LEGACY: single (prompt + optional token ids) def encode( self, prompts: str, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -390,8 +430,9 @@ def encode( def encode( self, prompts: List[str], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -402,8 +443,9 @@ def encode( def encode( self, prompts: Optional[str] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, *, prompt_token_ids: List[int], use_tqdm: bool = True, @@ -415,8 +457,9 @@ def encode( def encode( self, prompts: Optional[List[str]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, *, prompt_token_ids: List[List[int]], use_tqdm: bool = True, @@ -441,24 +484,29 @@ def encode( inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, use_tqdm: bool = True, lora_request: 
Optional[Union[List[LoRARequest], LoRARequest]] = None, ) -> List[EmbeddingRequestOutput]: ... - @deprecate_kwargs("prompts", - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter " - "instead.") + @deprecate_kwargs( + "prompts", + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " "instead.", + ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], - Optional[Union[str, List[str]]]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + prompts: Union[ + Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]], + ] = None, + pooling_params: Optional[ + Union[PoolingParams, Sequence[PoolingParams]] + ] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -478,7 +526,7 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: @@ -536,15 +584,19 @@ def _convert_v1_inputs( if prompts is not None: num_requests = len(prompts) if prompt_token_ids is not None: - if (num_requests is not None - and num_requests != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if num_requests is not None and num_requests != len( + prompt_token_ids + ): + raise ValueError( + "The lengths of prompts and prompt_token_ids " + "must be the same." + ) num_requests = len(prompt_token_ids) if num_requests is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") + raise ValueError( + "Either prompts or prompt_token_ids must be " "provided." + ) inputs: List[PromptInputs] = [] for i in range(num_requests): @@ -562,8 +614,12 @@ def _convert_v1_inputs( def _validate_and_add_requests( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], - params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, - Sequence[PoolingParams]], + params: Union[ + SamplingParams, + Sequence[SamplingParams], + PoolingParams, + Sequence[PoolingParams], + ], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: @@ -574,29 +630,31 @@ def _validate_and_add_requests( num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: - raise ValueError("The lengths of prompts and params " - "must be the same.") - if isinstance(lora_request, - list) and len(lora_request) != num_requests: - raise ValueError("The lengths of prompts and lora_request " - "must be the same.") + raise ValueError( + "The lengths of prompts and params " "must be the same." + ) + if isinstance(lora_request, list) and len(lora_request) != num_requests: + raise ValueError( + "The lengths of prompts and lora_request " "must be the same." + ) # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): self._add_request( request_inputs, params[i] if isinstance(params, Sequence) else params, - lora_request=lora_request[i] if isinstance( - lora_request, Sequence) else lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request[i] + if isinstance(lora_request, Sequence) + else lora_request, + prompt_adapter_request=prompt_adapter_request, + ) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], - LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( @@ -604,10 +662,11 @@ def _add_request( inputs, params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) def _run_engine( - self, *, use_tqdm: bool + self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -616,8 +675,10 @@ def _run_engine( total=num_requests, desc="Processed prompts", dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} toks/s, " - f"output: {0:.2f} toks/s"), + postfix=( + f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s" + ), ) # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -634,12 +695,15 @@ def _run_engine( total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs) - out_spd = total_out_toks / pbar.format_dict[ - "elapsed"] + len(stp.token_ids) for stp in output.outputs + ) + out_spd = ( + total_out_toks / pbar.format_dict["elapsed"] + ) pbar.postfix = ( f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s") + f"output: {out_spd:.2f} toks/s" + ) pbar.update(1) if use_tqdm: pbar.close()
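For context, here is a minimal usage sketch of the LLM.chat() API that this patch introduces, exercising the per-conversation SamplingParams pairing described in the docstrings above. It assumes the patch is applied; the model name is a placeholder for any checkpoint whose tokenizer ships a chat template, and the call signatures follow this diff rather than any released vLLM version.

# Minimal sketch of driving LLM.chat() as added in this patch (assumptions:
# patch applied; the model name below is a placeholder, not part of the diff).
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # placeholder chat model

conversations = [
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "What is dark matter?"},
    ],
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Tell me a joke."},
    ],
]

# One SamplingParams per conversation; the list length must match the number
# of conversations, mirroring the pairing rule documented in chat()/generate().
sampling_params = [
    SamplingParams(temperature=0.01, top_p=0.95),
    SamplingParams(temperature=0.3, top_p=0.95),
]

outputs = llm.chat(
    conversations,
    sampling_params=sampling_params,
    use_tqdm=False,
)

for output in outputs:
    print(output.prompt)
    print(output.outputs[0].text)
    print("=" * 80)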