Chat method for offline llm #5049

Merged · 16 commits · Aug 16, 2024
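This PR adds an `LLM.chat()` entrypoint so offline inference can consume OpenAI-style message lists directly instead of pre-formatted prompt strings. A minimal usage sketch, mirroring the example script added below (the model name and sampling settings are just the ones used in that example):

from vllm import LLM, SamplingParams

# Any instruct/chat model with a chat template works; this is the model
# used in the example script added by this PR.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

outputs = llm.chat(
    [{"role": "user", "content": "Hello"}],
    sampling_params=SamplingParams(temperature=0.5),
)
print(outputs[0].outputs[0].text)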
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -147,6 +147,7 @@ steps:
- pip install awscli tensorizer # for llava example and tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 offline_inference_vision_language.py
53 changes: 53 additions & 0 deletions examples/offline_inference_chat.py
@@ -0,0 +1,53 @@
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("-" * 80)


print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:

conversation = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation,
sampling_params=sampling_params,
use_tqdm=False)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
# chat_template = f.read()

# outputs = llm.chat(
#     conversation,
# sampling_params=sampling_params,
# use_tqdm=False,
# chat_template=chat_template,
# )
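As the commented-out block above notes, a custom template can be passed via `chat_template`; otherwise the tokenizer's built-in template is used. A chat template is a Jinja string rendered over the message list. Below is a minimal illustrative template and what it renders to (it is not the Falcon-180B template referenced above; the `add_generation_prompt` variable follows the Hugging Face convention):

from jinja2 import Template

# Illustrative only: one "role: content" line per message, plus an assistant
# cue when a generation prompt is requested.
DEMO_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant:{% endif %}"
)

rendered = Template(DEMO_CHAT_TEMPLATE).render(
    messages=[{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
)
print(rendered)
# user: Hello
# assistant: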
19 changes: 19 additions & 0 deletions tests/entrypoints/llm/test_generate.py
@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)


def test_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
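The new test only checks that a single output comes back. A sketch of a possible stricter follow-up pins sampling to greedy and compares `chat()` against `generate()` on a manually templated prompt; it assumes the Hugging Face `apply_chat_template` tokenizer method renders the same prompt that `chat()` builds internally:

from vllm import LLM, SamplingParams


def test_chat_matches_templated_generate():
    # Hypothetical follow-up: chat() should be equivalent to rendering the
    # conversation by hand and calling generate() on the resulting prompt.
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    greedy = SamplingParams(temperature=0, max_tokens=32)

    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Explain the concept of entropy."},
    ]

    # Render the prompt the chat() entrypoint is expected to build.
    prompt = llm.get_tokenizer().apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)

    chat_out = llm.chat(messages, sampling_params=greedy, use_tqdm=False)
    gen_out = llm.generate(prompt, sampling_params=greedy, use_tqdm=False)

    assert chat_out[0].outputs[0].text == gen_out[0].outputs[0].text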
121 changes: 93 additions & 28 deletions vllm/entrypoints/llm.py
@@ -6,6 +9,6 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
@@ -87,7 +90,7 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)

Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
@@ -138,8 +141,12 @@ def __init__(

if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = (
    "image_token_id",
    "image_feature_size",
    "image_input_shape",
    "image_input_type",
)
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
@@ -259,11 +266,13 @@ def generate(
) -> List[RequestOutput]:
...

@deprecate_kwargs(
    "prompts",
    "prompt_token_ids",
    is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
    additional_message="Please use the 'inputs' parameter "
    "instead.",
)
def generate(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -286,13 +295,13 @@ def generate(
Args:
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.

Returns:
@@ -339,6 +348,59 @@ def generate(
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)

def chat(
self,
messages: List[ChatCompletionMessageParam],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_template: bool = True,
) -> List[RequestOutput]:
"""
Generates responses for a chat conversation.

Converts the messages into a prompt using the tokenizer and calls
the `generate` method to produce the responses.

Args:
    messages: A list of messages to generate responses for. Each
        message is a dictionary with 'role' and 'content' keys.
    sampling_params: The sampling parameters for text generation.
        If None, we use the default sampling parameters. When it
        is a single value, it is applied to every prompt. When it
        is a list, the list must have the same length as the
        prompts and it is paired one by one with the prompt.
    use_tqdm: Whether to use tqdm to display the progress bar.
    lora_request: LoRA request to use for generation, if any.
    chat_template: The template to use for structuring the chat.
        If not provided, the model's default chat template will
        be used.
    add_generation_template: If True, appends a generation template
        to the end of the prompt so that the model generates the
        next assistant turn.

Returns:
    A list of `RequestOutput` objects containing the generated
    responses in the same order as the input messages.
"""

tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()

conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)

prompts = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_template=add_generation_template)

return self.generate(
prompts,
sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
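# For intuition (illustration only, model-specific): with
# add_generation_template=True the rendered prompt ends with the assistant
# header of the model's chat template, e.g. roughly for Llama 3
#
#   ...<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n
#
# so generate() continues the text as the assistant's reply. Without it, the
# prompt stops after the last user message and the model may simply continue
# that message rather than answer it.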

@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
@@ -413,11 +475,13 @@ def encode(
) -> List[EmbeddingRequestOutput]:
...

@deprecate_kwargs(
    "prompts",
    "prompt_token_ids",
    is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
    additional_message="Please use the 'inputs' parameter "
    "instead.",
)
def encode(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -443,7 +507,7 @@ def encode(
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.

Returns:
@@ -563,23 +627,24 @@ def _validate_and_add_requests(
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
)

def _add_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest],
LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
inputs,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)

def _add_guided_processor(
self,
@@ -628,8 +693,8 @@ def _run_engine(
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
           pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")