Chat method for offline llm (#5049)
Co-authored-by: nunjunj <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
6 people authored Aug 16, 2024
1 parent 4cd7d47 commit 3b19e39
Showing 4 changed files with 168 additions and 29 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -147,6 +147,7 @@ steps:
- pip install awscli tensorizer # for llava example and tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 offline_inference_vision_language.py
53 changes: 53 additions & 0 deletions examples/offline_inference_chat.py
@@ -0,0 +1,53 @@
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("-" * 80)


print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:

conversation = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation,
sampling_params=sampling_params,
use_tqdm=False)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
# chat_template = f.read()

# outputs = llm.chat(
#     conversation,
# sampling_params=sampling_params,
# use_tqdm=False,
# chat_template=chat_template,
# )
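
As a reference for the optional chat_template argument shown above, here is a minimal sketch of passing a template inline instead of reading it from a file. The Jinja string below is a made-up placeholder for illustration, not a template shipped with vLLM; most instruct models do not need it because their tokenizer already bundles a chat template.

# Illustrative only: a toy Jinja template that renders each message as
# "role: content" and ends with an assistant cue. Real chat templates are
# model-specific; prefer the model's built-in template when one exists.
custom_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant: "
)

outputs = llm.chat(
    conversation,
    sampling_params=sampling_params,
    use_tqdm=False,
    chat_template=custom_template,
)
print_outputs(outputs)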
19 changes: 19 additions & 0 deletions tests/entrypoints/llm/test_generate.py
@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)


def test_chat():

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
124 changes: 95 additions & 29 deletions vllm/entrypoints/llm.py
@@ -6,6 +6,9 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
@@ -87,7 +90,7 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
@@ -138,8 +141,12 @@ def __init__(

if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size",
"image_input_shape", "image_input_type")
removed_vision_keys = (
"image_token_id",
"image_feature_size",
"image_input_shape",
"image_input_type",
)
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
@@ -259,11 +266,12 @@ def generate(
) -> List[RequestOutput]:
...

@deprecate_kwargs("prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -286,17 +294,17 @@ def generate(
Args:
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `RequestOutput` objects containing the
A list of ``RequestOutput`` objects containing the
generated completions in the same order as the input prompts.
Note:
@@ -339,6 +347,62 @@ def generate(
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)

def chat(
self,
messages: List[ChatCompletionMessageParam],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_template: bool = True,
) -> List[RequestOutput]:
"""
Generates responses for chat messages.
Converts the messages to prompts using the tokenizer and calls
the :meth:`generate` method to generate the responses.
Args:
messages: A list of messages to generate responses for. Each
message is a list of dictionaries with 'role' and 'content'
keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
add_generation_template: If True, adds a generation template
to each message.
Returns:
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
"""

tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()

conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)

prompts = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_template=add_generation_template)

return self.generate(
prompts,
sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)

@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
Expand Down Expand Up @@ -413,11 +477,12 @@ def encode(
) -> List[EmbeddingRequestOutput]:
...

@deprecate_kwargs("prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -443,7 +508,7 @@ def encode(
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
@@ -563,23 +628,24 @@ def _validate_and_add_requests(
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)

def _add_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest],
LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
inputs,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)

def _add_guided_processor(
self,
@@ -628,8 +694,8 @@ def _run_engine(
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = total_out_toks / pbar.format_dict[
"elapsed"]
out_spd = (total_out_toks /
pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
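
Putting the pieces together, the new entrypoint can be exercised end to end roughly as follows. This is a sketch rather than part of the commit: it assumes the Meta-Llama-3-8B-Instruct checkpoint used in the example above is available, and max_tokens is a standard SamplingParams field chosen here purely for illustration.

from vllm import LLM, SamplingParams

# Build the engine once; chat() reuses the same LLM instance as generate().
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Summarize what an offline chat API is for."},
]

# Non-default sampling parameters, applied to the single formatted prompt.
params = SamplingParams(temperature=0.5, max_tokens=256)

outputs = llm.chat(messages, sampling_params=params, use_tqdm=False)
print(outputs[0].outputs[0].text)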
