Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Add OpenAI API support for input_audio #11027

Merged
merged 12 commits into from
Dec 17, 2024
6 changes: 3 additions & 3 deletions docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ We currently support the following OpenAI APIs:
- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst).
- *Note: `image_url.detail` parameter is not supported.*
- We also support `audio_url` content type for audio files.
- Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
- *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).*
- We support two audio content types.
- Support `input_audio` content type as defined [here](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in).
- Support `audio_url` content type for audio files. Refer to [here](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py#L51) for the exact schema.
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
- Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
Expand Down
26 changes: 26 additions & 0 deletions examples/openai_chat_completion_client_for_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,32 @@ def run_audio() -> None:
result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from base64 encoded audio:", result)

chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "input_audio",
"input_audio": {
# Any format supported by librosa is supported
"data": audio_base64,
"format": "wav"
},
},
],
}],
model=model,
max_completion_tokens=64,
)

result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from input audio:", result)


example_function_map = {
"text-only": run_text_only,
Expand Down
125 changes: 121 additions & 4 deletions tests/entrypoints/openai/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,61 @@ async def test_single_chat_session_audio_base64encoded(
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_input_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: Dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212)

message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})

# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
Expand Down Expand Up @@ -211,11 +266,72 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str,
base64_encoded_audio: Dict[str,
str]):
messages = [{
"role":
"user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason

# test streaming
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert delta.content
assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_url: str):
audio_url: str,
base64_encoded_audio: Dict[str, str]):

messages = [{
"role":
Expand All @@ -228,9 +344,10 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
}
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
Expand Down
65 changes: 55 additions & 10 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam)
ChatCompletionContentPartImageParam,
ChatCompletionContentPartInputAudioParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
Expand Down Expand Up @@ -105,6 +106,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):

ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
CustomChatCompletionContentSimpleAudioParam,
Expand Down Expand Up @@ -519,6 +521,10 @@ def parse_image(self, image_url: str) -> None:
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError

@abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
raise NotImplementedError

@abstractmethod
def parse_video(self, video_url: str) -> None:
raise NotImplementedError
Expand All @@ -545,6 +551,15 @@ def parse_audio(self, audio_url: str) -> None:
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio = get_and_parse_audio(audio_url)

placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)

def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)

Expand Down Expand Up @@ -574,6 +589,15 @@ def parse_audio(self, audio_url: str) -> None:
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)

placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)

def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)

Expand Down Expand Up @@ -667,17 +691,22 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
MM_PARSER_MAP: Dict[str,
Callable[[ChatCompletionContentPartParam],
Union[str, Dict[str,str]]]] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"input_audio":
lambda part: _InputAudioParser(part).get("input_audio", {}),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
"video_url":
Expand All @@ -686,7 +715,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],


def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
part: ChatCompletionContentPartParam) -> Tuple[str,
Union[str, Dict[str, str]]]:
"""
Parses a given multi-modal content part based on its type.

Expand Down Expand Up @@ -717,6 +747,7 @@ def _parse_chat_message_content_mm_part(
return part_type, content

# Handle missing 'type' but provided direct URL fields.
# 'type' is required field by pydantic
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
Expand All @@ -726,6 +757,9 @@ def _parse_chat_message_content_mm_part(
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None:
input_audio_params = cast(Dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
part)
Expand All @@ -739,7 +773,7 @@ def _parse_chat_message_content_mm_part(


VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"audio_url", "video_url")
"audio_url", "input_audio", "video_url")


def _parse_chat_message_content_parts(
Expand Down Expand Up @@ -795,7 +829,7 @@ def _parse_chat_message_content_part(
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url/video_url but
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning(
Expand All @@ -804,18 +838,30 @@ def _parse_chat_message_content_part(
return None

if part_type in ("text", "refusal"):
return {'type': 'text', 'text': content} if wrap_dicts else content
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
else:
return str_content

if part_type == "image_url":
mm_parser.parse_image(content)
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "audio_url":
mm_parser.parse_audio(content)
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
return {'type': 'audio'} if wrap_dicts else None

if part_type == "input_audio":
dict_content = cast(Dict[str, str], content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None

if part_type == "video_url":
mm_parser.parse_video(content)
str_content = cast(str, content)
mm_parser.parse_video(str_content)
return {'type': 'video'} if wrap_dicts else None

raise NotImplementedError(f"Unknown part type: {part_type}")
Expand All @@ -840,7 +886,6 @@ def _parse_chat_message_content(
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]

result = _parse_chat_message_content_parts(
role,
content, # type: ignore
Expand Down
Loading