Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/zai/api_resource/audio/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Mapping, Optional, cast

import httpx
from httpx import stream

from zai.core import (
NOT_GIVEN,
Expand All @@ -23,6 +22,8 @@
from zai.types.sensitive_word_check import SensitiveWordCheckRequest

from .transcriptions import Transcriptions
from zai.core._streaming import StreamResponse
from zai.types.audio import AudioSpeechChunk

if TYPE_CHECKING:
from zai._client import ZaiClient
Expand Down Expand Up @@ -60,7 +61,7 @@ def speech(
speed: float | None = 1.0,
volume: float | None = 1.0,
stream: bool | None = False
) -> HttpxBinaryResponseContent:
) -> HttpxBinaryResponseContent | StreamResponse[AudioSpeechChunk]:
"""
Generate speech audio from text input

Expand All @@ -83,7 +84,6 @@ def speech(
'voice': voice,
'response_format': response_format,
'encode_format': encode_format,
'sensitive_word_check': sensitive_word_check,
'request_id': request_id,
'user_id': user_id,
'speed': speed,
Expand All @@ -96,6 +96,8 @@ def speech(
body=maybe_transform(body, AudioSpeechParams),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=HttpxBinaryResponseContent,
stream=stream or False,
stream_cls=StreamResponse[AudioSpeechChunk]
)

def customization(
Expand Down
3 changes: 2 additions & 1 deletion src/zai/types/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .audio_customization_param import AudioCustomizationParam
from .audio_speech_chunk import AudioSpeechChunk
from .audio_speech_params import AudioSpeechParams
from .transcriptions_create_param import TranscriptionsParam

__all__ = ['AudioSpeechParams', 'AudioCustomizationParam', 'TranscriptionsParam']
__all__ = ['AudioSpeechParams', 'AudioCustomizationParam', 'TranscriptionsParam', 'AudioSpeechChunk']
32 changes: 32 additions & 0 deletions src/zai/types/audio/audio_speech_chunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List, Optional, Dict, Any

from ...core import BaseModel

# Public names exported by this module: the streamed chunk model plus the
# nested models it is built from.
__all__ = [
"AudioSpeechChunk",
"AudioError",
"AudioSpeechChoice",
"AudioSpeechDelta"
]


class AudioSpeechDelta(BaseModel):
    """Incremental payload carried by a single choice of a speech chunk.

    Both fields are optional and default to ``None``.
    """

    # Content fragment of this delta.
    # NOTE(review): presumably base64-encoded audio when the request is made
    # with encode_format='base64' -- confirm against the API documentation.
    content: Optional[str] = None
    # Role associated with this delta, if the server sends one.
    # NOTE(review): the set of possible values is not visible here -- confirm.
    role: Optional[str] = None


class AudioSpeechChoice(BaseModel):
    """One entry of ``AudioSpeechChunk.choices``.

    ``delta`` and ``index`` are required; ``finish_reason`` is optional.
    """

    # Incremental payload for this choice (required field, no default).
    # NOTE(review): callers that test ``choice.delta is None`` imply the
    # server may omit it; if so this should be Optional -- confirm.
    delta: AudioSpeechDelta
    # Defaults to None.
    # NOTE(review): presumably populated only on the last chunk -- confirm.
    finish_reason: Optional[str] = None
    # Position of this choice (required field, no default).
    index: int

class AudioError(BaseModel):
    """Error details that may be attached to an ``AudioSpeechChunk``.

    Both fields are optional and default to ``None``.
    """

    # Error code as a string.
    code: Optional[str] = None
    # Error description.
    message: Optional[str] = None


class AudioSpeechChunk(BaseModel):
    """One item of a streamed speech response.

    Only ``choices`` is required; the remaining fields default to ``None``.
    """

    # Choices carried by this chunk (required field, no default).
    # NOTE(review): if the server can send an error-only chunk without
    # ``choices``, validation of that chunk would fail -- confirm.
    choices: List[AudioSpeechChoice]
    # Identifier of the request this chunk belongs to, when provided.
    request_id: Optional[str] = None
    # Creation time as an integer.
    # NOTE(review): presumably a Unix timestamp; unit not visible here -- confirm.
    created: Optional[int] = None
    # Error details, when provided.
    error: Optional[AudioError] = None
4 changes: 4 additions & 0 deletions src/zai/types/audio/audio_speech_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ class AudioSpeechParams(TypedDict, total=False):
sensitive_word_check: Optional[SensitiveWordCheckRequest]
request_id: str
user_id: str
encode_format: str
speed: float
volume: float
stream: bool
14 changes: 11 additions & 3 deletions tests/integration_tests/test_audio.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import base64
import logging
import logging.config
from pathlib import Path

import zai
from zai import ZaiClient


def test_audio_speech(logging_conf):
logging.config.dictConfig(logging_conf) # type: ignore
client = ZaiClient() # Fill in your own API Key
Expand All @@ -17,11 +17,19 @@ def test_audio_speech(logging_conf):
voice='female',
response_format='pcm',
encode_format='base64',
stream=False,
stream=True,
speed=1.0,
volume=1.0,
)
response.stream_to_file(speech_file_path)
with open("output.pcm", "wb") as f:
for item in response:
choice = item.choices[0]
index = choice.index
finish_reason = choice.finish_reason
if choice.delta is None:
break
audio_delta = choice.delta.content
f.write(base64.b64decode(audio_delta))

except zai.core._errors.APIRequestFailedError as err:
print(err)
Expand Down