[python] Use vllm chat object (#2659)
Co-authored-by: Siddharth Venkatesan <[email protected]>
xyang16 and siddvenk authored Jan 17, 2025
1 parent d717cae commit ab53670
Showing 9 changed files with 161 additions and 20 deletions.
24 changes: 24 additions & 0 deletions engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Optional
from pydantic import Field
from vllm.entrypoints.openai.protocol import ChatCompletionRequest


class ChatProperties(ChatCompletionRequest):
"""
Chat input parameters for chat completions API.
See https://platform.openai.com/docs/api-reference/chat/create
"""

model: Optional[str] = Field(default=None, exclude=True) # Unused
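For orientation, a minimal usage sketch (illustrative values only; assumes vllm is installed so the base-class import resolves): the request body is validated by pydantic, the unused model field is excluded from serialization, and the remaining fields become generation parameters.

# Hypothetical usage sketch; the input_map values are illustrative.
from djl_python.chat_completions.vllm_chat_properties import ChatProperties

input_map = {
    "model": "ignored",  # dropped on dump via Field(exclude=True) above
    "messages": [{"role": "user", "content": "What is Deep Java Library?"}],
    "temperature": 0.9,
    "max_tokens": 256,
}
chat_params = ChatProperties(**input_map)
# "messages" is handled separately downstream; the rest become parameters.
params = chat_params.model_dump(exclude_none=True, exclude={"messages"})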
81 changes: 81 additions & 0 deletions engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict, List, Optional, Union

from djl_python.chat_completions.vllm_chat_properties import ChatProperties
from djl_python.properties_manager.properties import Properties
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)


def is_chat_completions_request(inputs: Dict) -> bool:
return "messages" in inputs


def parse_chat_completions_request_vllm(
input_map: Dict,
is_rolling_batch: bool,
rolling_batch,
tokenizer,
chat_template: Optional[str] = None,
image_token: Optional[str] = None,
        configs: Optional[Properties] = None,
is_mistral_tokenizer: bool = False,
):
    # Chat completions are supported only with rolling batch or with no batching (batch size 1).
if not (is_rolling_batch or configs.batch_size == 1):
raise ValueError(
"chat completions support is not currently available for dynamic batching. "
"You must enable rolling batch to use the chat completions format."
)

if not is_mistral_tokenizer and not hasattr(tokenizer,
"apply_chat_template"):
raise AttributeError(
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")

chat_params = ChatProperties(**input_map)
exclude = {"messages"}
param = chat_params.model_dump(exclude_none=True, exclude=exclude)

conversation, mm_data = parse_chat_messages(
chat_params.messages, rolling_batch.get_model_config(), tokenizer)

    text_inputs: Union[str, List[int]]
if is_mistral_tokenizer:
text_inputs = apply_mistral_chat_template(
tokenizer,
messages=chat_params.messages,
chat_template=chat_template,
add_generation_prompt=True,
)
else:
text_inputs = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=True,
)

param["details"] = True # Enable details for chat completions
    param["output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

if mm_data:
param["mm_data"] = mm_data

    # For the mistral tokenizer, text_inputs is a list of token ids; otherwise it is a string.
return text_inputs, param
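A hedged sketch of the resulting contract, where rolling_batch and tokenizer are placeholders for objects built elsewhere in the handler (the rolling batch must expose get_model_config(), and a non-mistral tokenizer must provide apply_chat_template):

# Sketch only; rolling_batch and tokenizer are hypothetical placeholders.
from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm

text_inputs, param = parse_chat_completions_request_vllm(
    {"messages": [{"role": "user", "content": "Hi"}], "stream": False},
    True,  # is_rolling_batch
    rolling_batch,
    tokenizer,
)
assert param["details"] is True
assert param["output_formatter"] == "json_chat"  # non-streaming request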
30 changes: 22 additions & 8 deletions engines/python/setup/djl_python/input_parser.py
@@ -16,6 +16,7 @@

from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm
from djl_python.encode_decode import decode
from djl_python.properties_manager.properties import is_rolling_batch_enabled
from djl_python.request import Request
@@ -140,14 +141,27 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
if configs is not None:
is_bedrock = configs.bedrock_compat
if is_chat_completions_request(input_map):
inputs, param = parse_chat_completions_request(
input_map,
kwargs.get("is_rolling_batch"),
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
rolling_batch = kwargs.get("rolling_batch")
            if rolling_batch is not None and rolling_batch.use_vllm_chat_completions():
inputs, param = parse_chat_completions_request_vllm(
input_map,
kwargs.get("is_rolling_batch"),
rolling_batch,
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
else:
inputs, param = parse_chat_completions_request(
input_map,
kwargs.get("is_rolling_batch"),
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
elif is_bedrock:
inputs, param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
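The dispatch above is capability-based: the parser asks the configured rolling batch whether it supports the vLLM-native chat path via use_vllm_chat_completions() (declared on the base class below) and falls back to the legacy parse_chat_completions_request otherwise, so engines opt in by overriding one method rather than input_parser accumulating engine-specific branches.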
@@ -126,6 +126,19 @@ def get_tokenizer(self):
return self.engine.preprocessor.tokenizer
return self.engine.preprocessor.tokenizer.tokenizer

def get_model_config(self):
# TODO: this is a hack right now to get the model config from the engine. We should expose this as
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config if not self.is_t5_model else None

def use_vllm_chat_completions(self):
return True

def get_huggingface_model_config(self):
# TODO: this is a hack right now to get the model config from the engine. We should expose this as
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None

12 changes: 12 additions & 0 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -101,13 +101,25 @@ def get_tokenizer(self):
"""
raise RuntimeError("get_tokenizer function not supported")

def get_model_config(self):
"""
:return: the model config if available
"""
raise RuntimeError("get_model_config must be implemented by subclass")

def get_huggingface_model_config(self):
"""
:return: the huggingface pretrained config if available
"""
raise RuntimeError(
"get_huggingface_model_config must be implemented by subclass")

def use_vllm_chat_completions(self):
"""
        :return: whether chat completions requests should be parsed with the vLLM chat utilities.
"""
return False

@abstractmethod
def inference(self, new_requests: List[Request]) -> List:
"""
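A minimal sketch of how a concrete engine opts in (the class name and engine attribute here are hypothetical; real subclasses also implement inference()):

# Hypothetical engine adapter, modeled on the vLLM subclass below.
from djl_python.rolling_batch.rolling_batch import RollingBatch

class MyEngineRollingBatch(RollingBatch):

    def get_model_config(self):
        # Expose the engine's model config for chat message parsing.
        return self.engine.model_config

    def use_vllm_chat_completions(self):
        # Route chat requests through parse_chat_completions_request_vllm.
        return True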
@@ -304,18 +304,9 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
)


def get_multi_modal_data(request: Request) -> Optional[dict]:
parameters = request.parameters
images = parameters.pop("images", None)
multi_modal_data = None
if images:
multi_modal_data = {"image": images}
return multi_modal_data


def get_prompt_inputs(request: Request):
text_prompt = request.request_input.input_text
multi_modal_data = get_multi_modal_data(request)
multi_modal_data = request.parameters.pop("mm_data", None)
# TODO: In chat cases, we need to apply the chat template to the messages object to get a string
# In both HuggingFace and mistral cases, that process can also yield token-ids directly
# that we may want to consider passing directly to the engine
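The effect of this change: multimodal data is no longer rebuilt from an images parameter here; it arrives pre-parsed under the mm_data key (set by parse_chat_completions_request_vllm) and is popped straight into the engine prompt. Roughly, for vllm 0.6.x (a sketch, not the repo's exact code):

# Sketch of the resulting flow under vllm 0.6.x.
from vllm.inputs import TextPrompt

def get_prompt_inputs(request):
    prompt = TextPrompt(prompt=request.request_input.input_text)
    multi_modal_data = request.parameters.pop("mm_data", None)
    if multi_modal_data is not None:
        prompt["multi_modal_data"] = multi_modal_data
    return prompt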
@@ -57,9 +57,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
def get_tokenizer(self):
return self.engine.tokenizer.tokenizer

def get_model_config(self):
return self.engine.model_config

def get_huggingface_model_config(self):
return self.engine.model_config.hf_config

def use_vllm_chat_completions(self):
return True

def reset(self) -> None:
"""
Aborts all requests
2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
test_requirements = [
'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
'pydantic>=2.0', "objgraph"
'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
]

setup(name='djl_python',
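The vllm pin in the test requirements presumably lets the new djl_python.chat_completions modules, which import vllm at module scope, be exercised in unit tests; 0.6.3.post1 matches the engine version this code path targets.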
2 changes: 1 addition & 1 deletion tests/integration/llm/client.py
@@ -1918,7 +1918,7 @@ def get_multimodal_prompt(batch_size):
"messages": messages,
"temperature": 0.9,
"top_p": 0.6,
"max_new_tokens": 512,
"max_tokens": 512,
}
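Since ChatProperties now derives from vLLM's OpenAI-compatible ChatCompletionRequest, chat requests use the OpenAI field name max_tokens; the DJL-native max_new_tokens no longer applies to the chat schema, hence the test payload update.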

