Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AzureOpenAiTextToSpeechDriver #1150

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ jobs:
AZURE_OPENAI_API_KEY_2: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_2 }}
AZURE_OPENAI_ENDPOINT_3: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_3 }}
AZURE_OPENAI_API_KEY_3: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_3 }}
AZURE_OPENAI_ENDPOINT_4: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_4 }}
AZURE_OPENAI_API_KEY_4: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_4 }}
AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_16K_DEPLOYMENT_ID }}
AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_DEPLOYMENT_ID }}
AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_DAVINCI_DEPLOYMENT_ID }}
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.

### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
Expand Down
20 changes: 20 additions & 0 deletions docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

from griptape.drivers import AzureOpenAiTextToSpeechDriver
from griptape.engines import TextToSpeechEngine
from griptape.structures import Agent
from griptape.tools.text_to_speech.tool import TextToSpeechTool

driver = AzureOpenAiTextToSpeechDriver(
api_key=os.environ["AZURE_OPENAI_API_KEY_4"],
model="tts",
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_4"],
)

tool = TextToSpeechTool(
engine=TextToSpeechEngine(
text_to_speech_driver=driver,
),
)

Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'")
8 changes: 8 additions & 0 deletions docs/griptape-framework/drivers/text-to-speech-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ The [OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_spee
```python
--8<-- "docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py"
```

## Azure OpenAI

The [Azure OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.md) provides support for text-to-speech models hosted in your Azure OpenAI instance. This Driver supports configurations specific to OpenAI, like voice selection and output format.

```python
--8<-- "docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py"
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
from .text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver
from .text_to_speech.elevenlabs_text_to_speech_driver import ElevenLabsTextToSpeechDriver
from .text_to_speech.openai_text_to_speech_driver import OpenAiTextToSpeechDriver
from .text_to_speech.azure_openai_text_to_speech_driver import AzureOpenAiTextToSpeechDriver

from .structure_run.base_structure_run_driver import BaseStructureRunDriver
from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver
Expand Down Expand Up @@ -227,6 +228,7 @@
"DummyTextToSpeechDriver",
"ElevenLabsTextToSpeechDriver",
"OpenAiTextToSpeechDriver",
"AzureOpenAiTextToSpeechDriver",
"BaseStructureRunDriver",
"GriptapeCloudStructureRunDriver",
"LocalStructureRunDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Callable, Optional

import openai
from attrs import Factory, define, field

from griptape.drivers import OpenAiTextToSpeechDriver


@define
class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver):
"""Azure OpenAi Text to Speech Driver.

Attributes:
azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
client: An `openai.AzureOpenAI` client.
"""

model: str = field(default="tts", kw_only=True, metadata={"serializable": True})
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default model is tts not tts-1

azure_deployment: str = field(
kw_only=True,
default=Factory(lambda self: self.model, takes_self=True),
metadata={"serializable": True},
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True,
default=None,
metadata={"serializable": False},
)
api_version: str = field(default="2024-07-01-preview", kw_only=True, metadata={"serializable": True})
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need the latest preview API version, not GA yet.

client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from unittest.mock import Mock

import pytest

from griptape.drivers import AzureOpenAiTextToSpeechDriver


class TestAzureOpenAiTextToSpeechDriver:
@pytest.fixture()
def mock_speech_create(self, mocker):
mock_speech_create = mocker.patch("openai.AzureOpenAI").return_value.audio.speech.create
mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id")
mock_function.name = "MockTool_test"
mock_speech_create.return_value = Mock(
content=b"speech",
)

return mock_speech_create

def test_init(self):
assert AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar", azure_deployment="foobar")
assert AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar").azure_deployment == "tts"

def test_run_text_to_audio(self, mock_speech_create):
driver = AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar")
output = driver.run_text_to_audio(["foo", "bar"])
mock_speech_create.assert_called_once_with(
input="foo. bar",
model=driver.model,
response_format=driver.format,
voice=driver.voice,
)
assert output.value == b"speech"
Loading