[python] Use vllm chat object #2659

Merged · 7 commits · Jan 17, 2025
Changes from all commits
24 changes: 24 additions & 0 deletions engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Optional
from pydantic import Field
from vllm.entrypoints.openai.protocol import ChatCompletionRequest


class ChatProperties(ChatCompletionRequest):
"""
Chat input parameters for chat completions API.
See https://platform.openai.com/docs/api-reference/chat/create
"""

model: Optional[str] = Field(default=None, exclude=True) # Unused
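For orientation, a minimal sketch (not part of this PR) of how the properties object is used by the chat parser below; the payload is hypothetical and follows the OpenAI chat-completions shape:

payload = {
    "model": "ignored",  # optional and excluded from dumps, per the field above
    "messages": [{"role": "user", "content": "Say hello"}],
    "temperature": 0.9,
    "stream": False,
}
chat_params = ChatProperties(**payload)
# Keep only the sampling parameters; "messages" is handled separately via parse_chat_messages.
param = chat_params.model_dump(exclude_none=True, exclude={"messages"})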
81 changes: 81 additions & 0 deletions engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict, List, Optional, Union

from djl_python.chat_completions.vllm_chat_properties import ChatProperties
from djl_python.properties_manager.properties import Properties
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)


def is_chat_completions_request(inputs: Dict) -> bool:
return "messages" in inputs


def parse_chat_completions_request_vllm(
input_map: Dict,
is_rolling_batch: bool,
rolling_batch,
tokenizer,
chat_template: Optional[str] = None,
image_token: Optional[str] = None,
configs: Properties = None,
is_mistral_tokenizer: bool = False,
):
    # Chat completions are supported with rolling batch, or with no-batching (batch_size == 1).
if not (is_rolling_batch or configs.batch_size == 1):
raise ValueError(
"chat completions support is not currently available for dynamic batching. "
"You must enable rolling batch to use the chat completions format."
)

if not is_mistral_tokenizer and not hasattr(tokenizer,
"apply_chat_template"):
raise AttributeError(
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")

chat_params = ChatProperties(**input_map)
exclude = {"messages"}
param = chat_params.model_dump(exclude_none=True, exclude=exclude)

conversation, mm_data = parse_chat_messages(
chat_params.messages, rolling_batch.get_model_config(), tokenizer)

prompt_data: Union[str, List[int]]
if is_mistral_tokenizer:
text_inputs = apply_mistral_chat_template(
tokenizer,
messages=chat_params.messages,
chat_template=chat_template,
add_generation_prompt=True,
)
else:
text_inputs = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=True,
)

param["details"] = True # Enable details for chat completions
param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

if mm_data:
param["mm_data"] = mm_data

    # For mistral tokenizers, text_inputs is a List[int] of token ids; otherwise it is a str.
return text_inputs, param
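For context, an illustrative sketch (not part of this PR) of the two request shapes this module tells apart; both payloads are made up, and the rolling_batch/tokenizer arguments of parse_chat_completions_request_vllm come from the caller in input_parser.py below:

# A request is routed to the chat path purely by the presence of "messages".
chat_payload = {
    "messages": [{"role": "user", "content": "What is Deep Java Library?"}],
    "stream": True,
    "max_tokens": 256,
}
completion_payload = {
    "inputs": "What is Deep Java Library?",
    "parameters": {"max_new_tokens": 256},
}

assert is_chat_completions_request(chat_payload)
assert not is_chat_completions_request(completion_payload)
# For the chat payload, parse_chat_completions_request_vllm returns the templated prompt
# (or mistral token ids) plus params carrying details=True and an output_formatter of
# "jsonlines_chat" (streaming) or "json_chat" (non-streaming).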
30 changes: 22 additions & 8 deletions engines/python/setup/djl_python/input_parser.py
@@ -16,6 +16,7 @@

from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm
from djl_python.encode_decode import decode
from djl_python.properties_manager.properties import is_rolling_batch_enabled
from djl_python.request import Request
@@ -140,14 +141,27 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
if configs is not None:
is_bedrock = configs.bedrock_compat
if is_chat_completions_request(input_map):
inputs, param = parse_chat_completions_request(
input_map,
kwargs.get("is_rolling_batch"),
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
rolling_batch = kwargs.get("rolling_batch")
if rolling_batch is not None and rolling_batch.use_vllm_chat_completions(
):
inputs, param = parse_chat_completions_request_vllm(
input_map,
kwargs.get("is_rolling_batch"),
rolling_batch,
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
else:
inputs, param = parse_chat_completions_request(
input_map,
kwargs.get("is_rolling_batch"),
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
elif is_bedrock:
inputs, param = parse_3p_request(input_map,
kwargs.get("is_rolling_batch"),
@@ -126,6 +126,19 @@ def get_tokenizer(self):
return self.engine.preprocessor.tokenizer
return self.engine.preprocessor.tokenizer.tokenizer

def get_model_config(self):
# TODO: this is a hack right now to get the model config from the engine. We should expose this as
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config if not self.is_t5_model else None

def use_vllm_chat_completions(self):
return True

def get_huggingface_model_config(self):
# TODO: this is a hack right now to get the model config from the engine. We should expose this as
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None

def get_huggingface_model_config(self):
# TODO: this is a hack right now to get the model config from the engine. We should expose this as
# an interface method and retrieve it from there after v12
12 changes: 12 additions & 0 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -101,13 +101,25 @@ def get_tokenizer(self):
"""
raise RuntimeError("get_tokenizer function not supported")

def get_model_config(self):
"""
:return: the model config if available
"""
raise RuntimeError("get_model_config must be implemented by subclass")

def get_huggingface_model_config(self):
"""
:return: the huggingface pretrained config if available
"""
raise RuntimeError(
"get_huggingface_model_config must be implemented by subclass")

def use_vllm_chat_completions(self):
"""
:return: whether to use the vllm chat completions.
"""
return False

@abstractmethod
def inference(self, new_requests: List[Request]) -> List:
"""
@@ -304,18 +304,9 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
)


def get_multi_modal_data(request: Request) -> Optional[dict]:
parameters = request.parameters
images = parameters.pop("images", None)
multi_modal_data = None
if images:
multi_modal_data = {"image": images}
return multi_modal_data


def get_prompt_inputs(request: Request):
text_prompt = request.request_input.input_text
multi_modal_data = get_multi_modal_data(request)
multi_modal_data = request.parameters.pop("mm_data", None)
# TODO: In chat cases, we need to apply the chat template to the messages object to get a string
# In both HuggingFace and mistral cases, that process can also yield token-ids directly
# that we may want to consider passing directly to the engine
@@ -57,9 +57,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
def get_tokenizer(self):
return self.engine.tokenizer.tokenizer

def get_model_config(self):
return self.engine.model_config

def get_huggingface_model_config(self):
return self.engine.model_config.hf_config

def use_vllm_chat_completions(self):
return True

def reset(self) -> None:
"""
Aborts all requests
2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
test_requirements = [
'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
'pydantic>=2.0', "objgraph"
'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
]

setup(name='djl_python',
2 changes: 1 addition & 1 deletion tests/integration/llm/client.py
@@ -1918,7 +1918,7 @@ def get_multimodal_prompt(batch_size):
"messages": messages,
"temperature": 0.9,
"top_p": 0.6,
"max_new_tokens": 512,
"max_tokens": 512,
}
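For reference, a hedged sketch (not part of this PR) of posting such a chat payload to a locally served model; the /invocations URL follows the integration client's local default and is an assumption here:

import requests

payload = {
    "messages": [{"role": "user", "content": "Describe deep learning in one sentence."}],
    "temperature": 0.9,
    "top_p": 0.6,
    "max_tokens": 512,  # OpenAI-style parameter name accepted by the vLLM chat request object
}
resp = requests.post("http://127.0.0.1:8080/invocations",
                     headers={"Content-Type": "application/json"},
                     json=payload)
print(resp.json())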

