diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index b2ad8bc3e..170dbc2d7 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -18,18 +18,21 @@ @define class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) + max_characters: int = field(kw_only=True, metadata={"serializable": True}) - def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + def before_run(self, prompt: str) -> None: + if len(prompt) > self.max_characters: + raise ValueError(f"Prompt exceeds maximum character limit of {self.max_characters}") + EventBus.publish_event(StartTextToSpeechEvent(prompt=prompt)) def after_run(self) -> None: EventBus.publish_event(FinishTextToSpeechEvent()) - def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: + def run_text_to_audio(self, prompt: str) -> AudioArtifact: for attempt in self.retrying(): with attempt: - self.before_run(prompts) - result = self.try_text_to_audio(prompts) + self.before_run(prompt) + result = self.try_text_to_audio(prompt) self.after_run() return result @@ -38,4 +41,4 @@ def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: raise Exception("Failed to run text to audio generation") @abstractmethod - def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: ... + def try_text_to_audio(self, prompt: str) -> AudioArtifact: ... 
diff --git a/griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py b/griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py index 0f5842bfe..067a85dce 100644 --- a/griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py @@ -14,6 +14,7 @@ @define class DummyTextToSpeechDriver(BaseTextToSpeechDriver): model: None = field(init=False, default=None, kw_only=True) + max_characters: None = field(init=False, default=None, kw_only=True) - def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: + def try_text_to_audio(self, prompt: str) -> AudioArtifact: raise DummyError(__class__.__name__, "try_text_to_audio") diff --git a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py index f4be58162..4a037084d 100644 --- a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Union from attrs import Factory, define, field @@ -8,11 +8,20 @@ from griptape.drivers import BaseTextToSpeechDriver from griptape.utils import import_optional_dependency +if TYPE_CHECKING: + import elevenlabs + import elevenlabs.client + @define class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver): api_key: str = field(kw_only=True, metadata={"serializable": True}) - client: Any = field( + model: Union[str, elevenlabs.Model] = field( + default=None, + kw_only=True, + metadata={"serializable": True}, + ) + client: elevenlabs.client.ElevenLabs = field( default=Factory( lambda self: import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key), takes_self=True, @@ -20,15 +29,26 @@ class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver): kw_only=True, metadata={"serializable": True}, ) - voice: str = 
field(kw_only=True, metadata={"serializable": True}) + voice: Union[str, elevenlabs.Voice] = field( + default=None, + kw_only=True, + metadata={"serializable": True}, + ) output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True}) + max_characters: int = field(default=10_000, kw_only=True, metadata={"serializable": True}) + + def try_text_to_audio(self, prompt: str) -> AudioArtifact: + kwargs = {} + if self.model is not None: + kwargs["model"] = self.model + if self.voice is not None: + kwargs["voice"] = self.voice + if self.output_format is not None: + kwargs["output_format"] = self.output_format - def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: audio = self.client.generate( - text=". ".join(prompts), - voice=self.voice, - model=self.model, - output_format=self.output_format, + text=prompt, + **kwargs, ) content = b"" diff --git a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py index 543ef1ec7..58ec9d2ed 100644 --- a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py @@ -18,11 +18,13 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver): metadata={"serializable": True}, ) format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True}) + speed: float = field(default=1.0, kw_only=True, metadata={"serializable": True}) api_type: Optional[str] = field(default=openai.api_type, kw_only=True) api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) + max_characters: int = 
field(default=4096, kw_only=True, metadata={"serializable": True}) client: openai.OpenAI = field( default=Factory( lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), @@ -30,12 +32,18 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver): ), ) - def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: + @speed.validator # pyright: ignore[reportAttributeAccessIssue] + def validate_speed(self, attribute: str, value: float) -> None: + if value < 0.25 or value > 4.0: + raise ValueError("Speed must be between 0.25 and 4.0") + + def try_text_to_audio(self, prompt: str) -> AudioArtifact: response = self.client.audio.speech.create( - input=". ".join(prompts), - voice=self.voice, + input=prompt, model=self.model, + voice=self.voice, response_format=self.format, + speed=self.speed, ) return AudioArtifact(value=response.content, format=self.format) diff --git a/griptape/engines/__init__.py b/griptape/engines/__init__.py index 7835b2238..d097d9252 100644 --- a/griptape/engines/__init__.py +++ b/griptape/engines/__init__.py @@ -9,8 +9,8 @@ from .image.inpainting_image_generation_engine import InpaintingImageGenerationEngine from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine from .image_query.image_query_engine import ImageQueryEngine -from .audio.text_to_speech_engine import TextToSpeechEngine -from .audio.audio_transcription_engine import AudioTranscriptionEngine +from .text_to_speech.text_to_speech_engine import TextToSpeechEngine +from .audio_transcription.audio_transcription_engine import AudioTranscriptionEngine __all__ = [ "BaseSummaryEngine", diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py deleted file mode 100644 index 1261ae369..000000000 --- a/griptape/engines/audio/text_to_speech_engine.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs 
import Factory, define, field - -from griptape.configs import Defaults - -if TYPE_CHECKING: - from griptape.artifacts.audio_artifact import AudioArtifact - from griptape.drivers import BaseTextToSpeechDriver - - -@define -class TextToSpeechEngine: - text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Defaults.drivers_config.text_to_speech_driver), kw_only=True - ) - - def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: - return self.text_to_speech_driver.try_text_to_audio(prompts=prompts) diff --git a/griptape/engines/audio/__init__.py b/griptape/engines/audio_transcription/__init__.py similarity index 100% rename from griptape/engines/audio/__init__.py rename to griptape/engines/audio_transcription/__init__.py diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio_transcription/audio_transcription_engine.py similarity index 100% rename from griptape/engines/audio/audio_transcription_engine.py rename to griptape/engines/audio_transcription/audio_transcription_engine.py diff --git a/griptape/engines/text_to_speech/__init__.py b/griptape/engines/text_to_speech/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/engines/text_to_speech/text_to_speech_engine.py b/griptape/engines/text_to_speech/text_to_speech_engine.py new file mode 100644 index 000000000..35521e9cd --- /dev/null +++ b/griptape/engines/text_to_speech/text_to_speech_engine.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from attrs import Factory, define, field + +from griptape.configs import Defaults + +if TYPE_CHECKING: + from griptape.artifacts.audio_artifact import AudioArtifact + from griptape.drivers import BaseTextToSpeechDriver + + +@define +class TextToSpeechEngine: + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: Defaults.drivers_config.text_to_speech_driver), kw_only=True + ) + prompt_joiner: str = 
field(default=". ", kw_only=True) + + def run(self, prompts: list[str], *args, **kwargs) -> list[AudioArtifact]: + prompt_str = self.prompt_joiner.join(prompts).strip() + new_prompts = [ + prompt_str[i : i + self.text_to_speech_driver.max_characters] + for i in range(0, len(prompt_str), self.text_to_speech_driver.max_characters) + ] + return [self.text_to_speech_driver.try_text_to_audio(prompt=prompt) for prompt in new_prompts] diff --git a/griptape/events/start_text_to_speech_event.py b/griptape/events/start_text_to_speech_event.py index 4c3f27ca0..6e150c520 100644 --- a/griptape/events/start_text_to_speech_event.py +++ b/griptape/events/start_text_to_speech_event.py @@ -7,4 +7,4 @@ @define class StartTextToSpeechEvent(BaseTextToSpeechEvent): - prompts: list[str] = field(kw_only=True, metadata={"serializable": True}) + prompt: str = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 680a67603..b2159518b 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.engines import TextToSpeechEngine from griptape.tasks.base_audio_generation_task import BaseAudioGenerationTask from griptape.utils import J2 @@ -34,10 +34,11 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - def run(self) -> AudioArtifact: - audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) + def run(self) -> ListArtifact[AudioArtifact]: + audio_artifacts = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) if self.output_dir or self.output_file: - self._write_to_file(audio_artifact) + for audio_artifact in audio_artifacts: + 
self._write_to_file(audio_artifact) - return audio_artifact + return ListArtifact(audio_artifacts) diff --git a/griptape/tools/text_to_speech/tool.py b/griptape/tools/text_to_speech/tool.py index aca259698..3f2e7bcc2 100644 --- a/griptape/tools/text_to_speech/tool.py +++ b/griptape/tools/text_to_speech/tool.py @@ -5,12 +5,13 @@ from attrs import define, field from schema import Literal, Schema +from griptape.artifacts import ListArtifact from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity if TYPE_CHECKING: - from griptape.artifacts import AudioArtifact, ErrorArtifact + from griptape.artifacts import AudioArtifact from griptape.engines import TextToSpeechEngine @@ -32,12 +33,13 @@ class TextToSpeechTool(ArtifactFileOutputMixin, BaseTool): "schema": Schema({Literal("text", description="The literal text to be converted to speech."): str}), }, ) - def text_to_speech(self, params: dict[str, Any]) -> AudioArtifact | ErrorArtifact: + def text_to_speech(self, params: dict[str, Any]) -> ListArtifact[AudioArtifact]: text = params["values"]["text"] - output_artifact = self.engine.run(prompts=[text]) + output_artifacts = self.engine.run(prompts=[text]) if self.output_dir or self.output_file: - self._write_to_file(output_artifact) + for audio_artifact in output_artifacts: + self._write_to_file(audio_artifact) - return output_artifact + return ListArtifact(output_artifacts) diff --git a/tests/mocks/mock_text_to_speech_driver.py b/tests/mocks/mock_text_to_speech_driver.py index 14be84b51..e899eff16 100644 --- a/tests/mocks/mock_text_to_speech_driver.py +++ b/tests/mocks/mock_text_to_speech_driver.py @@ -10,6 +10,7 @@ class MockTextToSpeechDriver(BaseTextToSpeechDriver): model: str = field(default="test-model", kw_only=True) mock_output: str = field(default="mock output", kw_only=True) + max_characters: int = field(default=100, kw_only=True) - def try_text_to_audio(self, 
prompts: list[str]) -> AudioArtifact: + def try_text_to_audio(self, prompt: str) -> AudioArtifact: return AudioArtifact(value=self.mock_output, format="mp3") diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 83b9dd77c..dffca5d5d 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -95,6 +95,8 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiTextToSpeechDriver", "voice": "alloy", + "speed": 1.0, + "max_characters": 4096, }, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 016383c32..5b18aa30c 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -75,6 +75,8 @@ def test_to_dict(self, config): "model": "tts-1", "organization": None, "voice": "alloy", + "speed": 1.0, + "max_characters": 4096, }, "audio_transcription_driver": { "type": "OpenAiAudioTranscriptionDriver", diff --git a/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py b/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py index 5bab87c9e..fba39d815 100644 --- a/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py +++ b/tests/unit/drivers/text_to_speech/test_azure_openai_text_to_speech_driver.py @@ -23,11 +23,21 @@ def test_init(self): def test_run_text_to_audio(self, mock_speech_create): driver = AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar") - output = driver.run_text_to_audio(["foo", "bar"]) - mock_speech_create.assert_called_once_with( - input="foo. 
bar", + output1 = driver.run_text_to_audio("foo") + mock_speech_create.assert_called_with( + input="foo", model=driver.model, response_format=driver.format, voice=driver.voice, + speed=driver.speed, ) - assert output.value == b"speech" + output2 = driver.run_text_to_audio("bar") + mock_speech_create.assert_called_with( + input="bar", + model=driver.model, + response_format=driver.format, + voice=driver.voice, + speed=driver.speed, + ) + assert output1.value == b"speech" + assert output2.value == b"speech" diff --git a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py index 26c29adcf..4084ca94c 100644 --- a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py +++ b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py @@ -16,7 +16,7 @@ def test_init(self, driver): def test_try_text_to_audio(self, driver): driver.client.generate.return_value = [b"audio data"] - audio_artifact = driver.try_text_to_audio(prompts=["test prompt"]) + audio_artifact = driver.try_text_to_audio(prompt="test prompt") assert audio_artifact.value == b"audio data" assert audio_artifact.format == "mp3" diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index 44348fef0..faaa5ba7c 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.artifacts import AudioArtifact, ListArtifact, TextArtifact from griptape.engines import TextToSpeechEngine from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask @@ -30,16 +30,17 @@ def test_config_text_to_speech_engine(self): def test_calls(self): text_to_speech_engine = Mock() - text_to_speech_engine.run.return_value = AudioArtifact(b"audio 
content", format="mp3") + text_to_speech_engine.run.return_value = ListArtifact([AudioArtifact(b"audio content", format="mp3")]) - assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run().value == b"audio content" + assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run()[0].value == b"audio content" def test_run(self): text_to_speech_engine = Mock() - text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") + text_to_speech_engine.run.return_value = ListArtifact([AudioArtifact(b"audio content", format="mp3")]) task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine) pipeline = Pipeline() pipeline.add_task(task) - assert isinstance(pipeline.run().output, AudioArtifact) + assert isinstance(pipeline.run().output, ListArtifact) + assert isinstance(pipeline.run().output[0], AudioArtifact) # pyright: ignore[reportIndexIssue, reportOptionalSubscript] diff --git a/tests/unit/tools/test_text_to_speech_tool.py b/tests/unit/tools/test_text_to_speech_tool.py index 6f2c43bd3..ee99cf445 100644 --- a/tests/unit/tools/test_text_to_speech_tool.py +++ b/tests/unit/tools/test_text_to_speech_tool.py @@ -25,17 +25,17 @@ def test_validate_output_configs(self, text_to_speech_engine) -> None: def test_text_to_speech(self, text_to_speech_client) -> None: text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") - audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) + list_audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) - assert audio_artifact + assert list_audio_artifact def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile) - text_to_speech_client.engine.run.return_value = 
AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] + text_to_speech_client.engine.run.return_value = [AudioArtifact(value=b"audio data", format="mp3")] # pyright: ignore[reportFunctionMemberAccess] - audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) + list_audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) - assert audio_artifact + assert list_audio_artifact assert os.path.exists(outfile)