remove support for multiple chats
nunjunj committed Aug 8, 2024
1 parent a5ed382 commit c4ca2f3
Showing 3 changed files with 21 additions and 88 deletions.
27 changes: 1 addition & 26 deletions examples/offline_inference_chat.py
@@ -14,9 +14,8 @@ def print_outputs(outputs):

print("=" * 80)

# In this script, we demonstrate two ways to pass input to the chat method:
# In this script, we demonstrate how to pass input to the chat method:

# Conversation with a list of dictionaries
conversation = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
@@ -31,30 +30,6 @@ def print_outputs(outputs):
)
print_outputs(outputs)

# Multiple conversations
conversations = [
[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What is dark matter?"},
],
[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "How are you?"},
{
"role": "assistant",
"content": "I'm an AI without feelings, but I'm here to help!",
},
{"role": "user", "content": "Tell me a joke."},
],
]

outputs = llm.chat(
conversations,
sampling_params=sampling_params,
use_tqdm=False,
)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

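Because llm.chat now takes a single conversation, a caller that still wants several chats can simply loop over them and collect the results itself. The sketch below is illustrative only: it reuses the conversations from the removed example, and the model name and sampling settings are placeholder assumptions, not part of this commit.

# Minimal sketch: one llm.chat call per conversation, now that batched
# conversation input is no longer supported. Model name and sampling
# settings are illustrative placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # placeholder model
sampling_params = SamplingParams(temperature=0.5)

conversations = [
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "What is dark matter?"},
    ],
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Tell me a joke."},
    ],
]

# Each chat() call still returns a list of RequestOutput objects, so this
# produces a list of per-conversation output lists.
outputs = [
    llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
    for conversation in conversations
]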
22 changes: 0 additions & 22 deletions tests/entrypoints/llm/test_generate.py
@@ -159,25 +159,3 @@ def test_chat():
]
outputs = llm.chat(messages)
assert len(outputs) == 1

prompt2 = "Describe Bangkok in 150 words."
multiple_messages = [messages] + [[
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]]
outputs = llm.chat(multiple_messages)
assert len(outputs) == len(multiple_messages)

sampling_params = [
SamplingParams(temperature=0.01, top_p=0.95),
SamplingParams(temperature=0.3, top_p=0.95),
]

outputs = llm.chat(multiple_messages, sampling_params=sampling_params)
assert len(outputs) == len(messages)
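The deleted assertions could be re-expressed against the single-conversation API by calling llm.chat once per message list, each with its own SamplingParams. This is a hedged sketch rather than a test added by this commit; it reuses the prompt and sampling values from the removed code and assumes the `messages` and `llm` objects defined earlier in test_chat().

# Sketch only: per-conversation chat calls replacing the removed batched test.
prompt2 = "Describe Bangkok in 150 words."
second_messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": prompt2},
]
per_conversation_params = [
    SamplingParams(temperature=0.01, top_p=0.95),
    SamplingParams(temperature=0.3, top_p=0.95),
]

for conversation, params in zip([messages, second_messages],
                                per_conversation_params):
    outputs = llm.chat(conversation, sampling_params=params)
    assert len(outputs) == 1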
60 changes: 20 additions & 40 deletions vllm/entrypoints/llm.py
@@ -1,13 +1,15 @@
from contextlib import contextmanager
from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast,
from typing import (ClassVar, List, Optional, Sequence, Union, cast,
overload)

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, parse_chat_messages
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
parse_chat_messages,
apply_chat_template)
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
@@ -360,13 +362,9 @@ def generate(

def chat(
self,
messages: Union[
List[ChatCompletionMessageParam],
List[List[ChatCompletionMessageParam]]
],
sampling_params: Optional[
Union[SamplingParams, List[SamplingParams]]
] = None,
messages: List[ChatCompletionMessageParam],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
@@ -398,33 +396,17 @@ def chat(

tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()

if isinstance(messages[0], dict):
conversations, _ = parse_chat_messages(
messages,
model_config,
tokenizer
)

prompts = tokenizer.apply_chat_template(
conversations,
tokenize=False,
add_generation_template=add_generation_template,
chat_template=chat_template,
)

elif isinstance(messages[0], list):
conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)

prompts = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_template=add_generation_template
)

prompts = [
tokenizer.apply_chat_template(
parse_chat_messages(message, model_config, tokenizer)[0],
tokenize=False,
add_generation_template=add_generation_template,
chat_template=chat_template,
)
for message in messages
]

return self.generate(
prompts,
sampling_params,
@@ -436,9 +418,8 @@ def chat(
def encode(
self,
prompts: str,
pooling_params: Optional[
Union[PoolingParams, Sequence[PoolingParams]]
] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -672,9 +653,8 @@ def _validate_and_add_requests(
self._add_request(
request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i]
if isinstance(lora_request, Sequence)
else lora_request,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
)

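For reference, the chat-template expansion that the reworked chat() delegates to the tokenizer can be reproduced outside vLLM with Hugging Face transformers. The sketch below is a standalone illustration, not this commit's code path: the model name is a placeholder, and it uses transformers' add_generation_prompt flag rather than the add_generation_template keyword that appears in the diff.

# Standalone sketch of chat-template expansion, independent of this commit.
# The model name is a placeholder; any chat-tuned model with a template works.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
]

# tokenize=False returns the rendered prompt string instead of token IDs;
# add_generation_prompt appends the assistant turn header so generation can start.
prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)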
