From 5366ef7e11158af41ab406e4ac29860439057276 Mon Sep 17 00:00:00 2001 From: matt Date: Fri, 6 Sep 2024 10:25:37 -0500 Subject: [PATCH] Add `AzureOpenAiTextToSpeechDriver` --- .github/workflows/docs-integration-tests.yml | 2 + CHANGELOG.md | 1 + .../drivers/src/text_to_speech_drivers_3.py | 20 ++++++++ .../drivers/text-to-speech-drivers.md | 8 +++ griptape/drivers/__init__.py | 2 + .../azure_openai_text_to_speech_driver.py | 51 +++++++++++++++++++ ...test_azure_openai_text_to_speech_driver.py | 33 ++++++++++++ 7 files changed, 117 insertions(+) create mode 100644 docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py create mode 100644 griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py create mode 100644 tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py diff --git a/.github/workflows/docs-integration-tests.yml b/.github/workflows/docs-integration-tests.yml index 61111b20b..77dab4a5b 100644 --- a/.github/workflows/docs-integration-tests.yml +++ b/.github/workflows/docs-integration-tests.yml @@ -85,6 +85,8 @@ jobs: AZURE_OPENAI_API_KEY_2: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_2 }} AZURE_OPENAI_ENDPOINT_3: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_3 }} AZURE_OPENAI_API_KEY_3: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_3 }} + AZURE_OPENAI_ENDPOINT_4: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_4 }} + AZURE_OPENAI_API_KEY_4: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_4 }} AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_16K_DEPLOYMENT_ID }} AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_DEPLOYMENT_ID }} AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_DAVINCI_DEPLOYMENT_ID }} diff --git a/CHANGELOG.md b/CHANGELOG.md index e7d833612..a71401cf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Parameter `meta: dict` on `BaseEvent`. +- `AzureOpenAiTextToSpeechDriver`. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. diff --git a/docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py b/docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py new file mode 100644 index 000000000..87add5498 --- /dev/null +++ b/docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py @@ -0,0 +1,20 @@ +import os + +from griptape.drivers import AzureOpenAiTextToSpeechDriver +from griptape.engines import TextToSpeechEngine +from griptape.structures import Agent +from griptape.tools.text_to_speech.tool import TextToSpeechTool + +driver = AzureOpenAiTextToSpeechDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_4"], + model="tts", + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_4"], +) + +tool = TextToSpeechTool( + engine=TextToSpeechEngine( + text_to_speech_driver=driver, + ), +) + +Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'") diff --git a/docs/griptape-framework/drivers/text-to-speech-drivers.md b/docs/griptape-framework/drivers/text-to-speech-drivers.md index c5455914e..a6fb955e6 100644 --- a/docs/griptape-framework/drivers/text-to-speech-drivers.md +++ b/docs/griptape-framework/drivers/text-to-speech-drivers.md @@ -29,3 +29,11 @@ The [OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_spee ```python --8<-- "docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py" ``` + +## Azure OpenAI + +The [Azure OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.md) provides support for text-to-speech models hosted in your Azure OpenAI instance. This Driver supports configurations specific to OpenAI, like voice selection and output format. + +```python +--8<-- "docs/griptape-framework/drivers/src/text_to_speech_drivers_3.py" +``` diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index f19ec7d10..7d2de3552 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -118,6 +118,7 @@ from .text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver from .text_to_speech.elevenlabs_text_to_speech_driver import ElevenLabsTextToSpeechDriver from .text_to_speech.openai_text_to_speech_driver import OpenAiTextToSpeechDriver +from .text_to_speech.azure_openai_text_to_speech_driver import AzureOpenAiTextToSpeechDriver from .structure_run.base_structure_run_driver import BaseStructureRunDriver from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver @@ -227,6 +228,7 @@ "DummyTextToSpeechDriver", "ElevenLabsTextToSpeechDriver", "OpenAiTextToSpeechDriver", + "AzureOpenAiTextToSpeechDriver", "BaseStructureRunDriver", "GriptapeCloudStructureRunDriver", "LocalStructureRunDriver", diff --git a/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py new file mode 100644 index 000000000..562a1d637 --- /dev/null +++ b/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Callable, Optional + +import openai +from attrs import Factory, define, field + +from griptape.drivers import OpenAiTextToSpeechDriver + + +@define +class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver): + """Azure OpenAi Text to Speech Driver. + + Attributes: + azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name. + azure_endpoint: An Azure OpenAi endpoint. + azure_ad_token: An optional Azure Active Directory token. + azure_ad_token_provider: An optional Azure Active Directory token provider. + api_version: An Azure OpenAi API version. + client: An `openai.AzureOpenAI` client. + """ + + model: str = field(default="tts", kw_only=True, metadata={"serializable": True}) + azure_deployment: str = field( + kw_only=True, + default=Factory(lambda self: self.model, takes_self=True), + metadata={"serializable": True}, + ) + azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) + azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) + azure_ad_token_provider: Optional[Callable[[], str]] = field( + kw_only=True, + default=None, + metadata={"serializable": False}, + ) + api_version: str = field(default="2024-07-01-preview", kw_only=True, metadata={"serializable": True}) + client: openai.AzureOpenAI = field( + default=Factory( + lambda self: openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ), + takes_self=True, + ), + ) diff --git a/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py b/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py new file mode 100644 index 000000000..5bab87c9e --- /dev/null +++ b/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py @@ -0,0 +1,33 @@ +from unittest.mock import Mock + +import pytest + +from griptape.drivers import AzureOpenAiTextToSpeechDriver + + +class TestAzureOpenAiTextToSpeechDriver: + @pytest.fixture() + def mock_speech_create(self, mocker): + mock_speech_create = mocker.patch("openai.AzureOpenAI").return_value.audio.speech.create + mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id") + mock_function.name = "MockTool_test" + mock_speech_create.return_value = Mock( + content=b"speech", + ) + + return mock_speech_create + + def test_init(self): + assert AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar", azure_deployment="foobar") + assert AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar").azure_deployment == "tts" + + def test_run_text_to_audio(self, mock_speech_create): + driver = AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar") + output = driver.run_text_to_audio(["foo", "bar"]) + mock_speech_create.assert_called_once_with( + input="foo. bar", + model=driver.model, + response_format=driver.format, + voice=driver.voice, + ) + assert output.value == b"speech"