[Frontend] Add OpenAI API support for input_audio #11027

Merged: 12 commits, Dec 17, 2024
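This PR teaches vLLM's OpenAI-compatible frontend to accept OpenAI's `input_audio` chat content parts (base64-encoded audio plus a format tag) alongside the existing `audio_url` parts. A minimal client-side sketch of the request shape the tests below exercise, assuming a locally running vLLM server and an audio-capable model; the base URL, API key, model name, and file path are placeholders:

import base64

import openai

# Placeholders: point these at your own vLLM server and audio-capable model.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

chat_completion = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_3",  # placeholder audio model
    messages=[{
        "role": "user",
        "content": [
            {
                "type": "input_audio",
                "input_audio": {"data": audio_b64, "format": "wav"},
            },
            {"type": "text", "text": "What's happening in this audio?"},
        ],
    }],
    max_completion_tokens=10,
)
print(chat_completion.choices[0].message.content)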
125 changes: 121 additions & 4 deletions tests/entrypoints/openai/test_audio.py
@@ -154,6 +154,61 @@ async def test_single_chat_session_audio_base64encoded(
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_input_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: Dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=202, total_tokens=212)

    message = choice.message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})

# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
@@ -211,11 +266,72 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str,
base64_encoded_audio: Dict[str,
str]):
messages = [{
"role":
"user",
"content": [
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason

# test streaming
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
    # finish_reason should only be returned in the last chunk
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert delta.content
assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
-                                 audio_url: str):
+                                 audio_url: str,
+                                 base64_encoded_audio: Dict[str, str]):

messages = [{
"role":
@@ -228,9 +344,10 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
}
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
{
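The server-side half of the change, in `vllm/entrypoints/chat_utils.py` below, normalizes an `input_audio` part into the base64 data-URL form that the existing audio path already understands. A minimal standalone sketch of that conversion (the helper name is mine, not vLLM's):

def input_audio_to_data_url(input_audio: dict) -> str:
    # Mirrors parse_input_audio below: the base64 payload and format tag
    # become a data URL, so the existing audio_url machinery can be reused.
    data = input_audio.get("data", "")   # base64-encoded audio bytes
    fmt = input_audio.get("format", "")  # "wav" or "mp3"
    return f"data:audio/{fmt};base64,{data}"

assert (input_audio_to_data_url({"data": "UklGRg==", "format": "wav"})
        == "data:audio/wav;base64,UklGRg==")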
76 changes: 66 additions & 10 deletions vllm/entrypoints/chat_utils.py
@@ -13,7 +13,8 @@
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
-                               ChatCompletionContentPartImageParam)
+                               ChatCompletionContentPartImageParam,
+                               ChatCompletionContentPartInputAudioParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
@@ -92,6 +93,16 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
audio_url: Required[str]


class CustomChatCompletionContentInputAudioParam(TypedDict, total=False):
# Same as InputAudio type from https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion_content_part_input_audio_param.py
data: Required[str]
"""Base64 encoded audio data."""

format: Required[Literal["wav", "mp3"]]
"""The format of the encoded audio data.
Currently supports "wav" and "mp3"."""


class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.

@@ -105,6 +116,8 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):

ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
CustomChatCompletionContentInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
CustomChatCompletionContentSimpleAudioParam,
@@ -519,6 +532,10 @@ def parse_image(self, image_url: str) -> None:
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError

@abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
raise NotImplementedError

@abstractmethod
def parse_video(self, video_url: str) -> None:
raise NotImplementedError
@@ -545,6 +562,15 @@ def parse_audio(self, audio_url: str) -> None:
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
        input_audio_data = input_audio.get("data", "")
        input_audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio = get_and_parse_audio(audio_url)

placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)

def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)

@@ -574,6 +600,15 @@ def parse_audio(self, audio_url: str) -> None:
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
        input_audio_data = input_audio.get("data", "")
        input_audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)

placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)

def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)

@@ -667,17 +702,22 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

# Define a mapping from part types to their corresponding parsing functions.
-MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
+MM_PARSER_MAP: Dict[str,
+                    Callable[[ChatCompletionContentPartParam],
+                             Union[str, Dict[str, str]]]] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"input_audio":
lambda part: _InputAudioParser(part).get("input_audio", {}),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
"video_url":
@@ -686,7 +726,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],


def _parse_chat_message_content_mm_part(
-        part: ChatCompletionContentPartParam) -> Tuple[str, str]:
+        part: ChatCompletionContentPartParam) -> Tuple[str,
+                                                       Union[str, Dict[str, str]]]:
"""
Parses a given multi-modal content part based on its type.

@@ -717,6 +758,7 @@ def _parse_chat_message_content_mm_part(
return part_type, content

# Handle missing 'type' but provided direct URL fields.
    # 'type' is a required field by pydantic
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
Expand All @@ -726,6 +768,9 @@ def _parse_chat_message_content_mm_part(
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None:
input_audio_params = cast(Dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
part)
@@ -739,7 +784,7 @@


VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"audio_url", "video_url")
"audio_url", "input_audio", "video_url")


def _parse_chat_message_content_parts(
Expand Down Expand Up @@ -795,7 +840,7 @@ def _parse_chat_message_content_part(
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

-    # if part_type is text/refusal/image_url/audio_url/video_url but
+    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning(
@@ -804,18 +849,30 @@ def _parse_chat_message_content_part(
return None

if part_type in ("text", "refusal"):
-        return {'type': 'text', 'text': content} if wrap_dicts else content
+        str_content = cast(str, content)
+        if wrap_dicts:
+            return {'type': 'text', 'text': str_content}
+        else:
+            return str_content

if part_type == "image_url":
-        mm_parser.parse_image(content)
+        str_content = cast(str, content)
+        mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "audio_url":
-        mm_parser.parse_audio(content)
+        str_content = cast(str, content)
+        mm_parser.parse_audio(str_content)
return {'type': 'audio'} if wrap_dicts else None

if part_type == "input_audio":
dict_content = cast(Dict[str, str], content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None

if part_type == "video_url":
-        mm_parser.parse_video(content)
+        str_content = cast(str, content)
+        mm_parser.parse_video(str_content)
return {'type': 'video'} if wrap_dicts else None

raise NotImplementedError(f"Unknown part type: {part_type}")
@@ -840,7 +897,6 @@ def _parse_chat_message_content(
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
-
result = _parse_chat_message_content_parts(
role,
content, # type: ignore
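One subtlety worth noting: `input_audio` is the only part type whose extracted content is a dict rather than a URL string, which is why `_parse_chat_message_content_mm_part` now returns `Tuple[str, Union[str, Dict[str, str]]]` and the callers cast before dispatching. A toy sketch of the dispatch (standalone, with only two entries; the real map lives in `chat_utils.py`):

from typing import Callable, Dict, Union

# Toy stand-in for MM_PARSER_MAP with just two entries.
PARSERS: Dict[str, Callable[[dict], Union[str, Dict[str, str]]]] = {
    "audio_url": lambda part: part.get("audio_url", {}).get("url", ""),
    "input_audio": lambda part: part.get("input_audio", {}),
}

part = {"type": "input_audio",
        "input_audio": {"data": "UklGRg==", "format": "wav"}}
content = PARSERS[part["type"]](part)
assert content == {"data": "UklGRg==", "format": "wav"}  # a dict, not a str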