Skip to content

Commit

Permalink
Refactor TextToSpeechDrivers
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Sep 19, 2024
1 parent 9e9c304 commit cd8fd98
Show file tree
Hide file tree
Showing 20 changed files with 125 additions and 68 deletions.
15 changes: 9 additions & 6 deletions griptape/drivers/text_to_speech/base_text_to_speech_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,21 @@
@define
class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
model: str = field(kw_only=True, metadata={"serializable": True})
max_characters: int = field(kw_only=True, metadata={"serializable": True})

def before_run(self, prompts: list[str]) -> None:
EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts))
def before_run(self, prompt: str) -> None:
if len(prompt) > self.max_characters:
raise ValueError(f"Prompt exceeds maximum character limit of {self.max_characters}")
EventBus.publish_event(StartTextToSpeechEvent(prompt=prompt))

def after_run(self) -> None:
EventBus.publish_event(FinishTextToSpeechEvent())

def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
def run_text_to_audio(self, prompt: str) -> AudioArtifact:
for attempt in self.retrying():
with attempt:
self.before_run(prompts)
result = self.try_text_to_audio(prompts)
self.before_run(prompt)
result = self.try_text_to_audio(prompt)
self.after_run()

return result
Expand All @@ -38,4 +41,4 @@ def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
raise Exception("Failed to run text to audio generation")

@abstractmethod
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: ...
def try_text_to_audio(self, prompt: str) -> AudioArtifact: ...
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@define
class DummyTextToSpeechDriver(BaseTextToSpeechDriver):
model: None = field(init=False, default=None, kw_only=True)
max_characters: None = field(init=False, default=None, kw_only=True)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
def try_text_to_audio(self, prompt: str) -> AudioArtifact:
raise DummyError(__class__.__name__, "try_text_to_audio")
Original file line number Diff line number Diff line change
@@ -1,34 +1,54 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Union

from attrs import Factory, define, field

from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
import elevenlabs
import elevenlabs.client


@define
class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver):
api_key: str = field(kw_only=True, metadata={"serializable": True})
client: Any = field(
model: Union[str, elevenlabs.Model] = field(
default=None,
kw_only=True,
metadata={"serializable": True},
)
client: elevenlabs.client.ElevenLabs = field(
default=Factory(
lambda self: import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
voice: str = field(kw_only=True, metadata={"serializable": True})
voice: Union[str, elevenlabs.Voice] = field(
default=None,
kw_only=True,
metadata={"serializable": True},
)
output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True})
max_characters: int = field(default=10_000, kw_only=True, metadata={"serializable": True})

def try_text_to_audio(self, prompt: str) -> AudioArtifact:
kwargs = {}
if self.model is not None:
kwargs["model"] = self.model
if self.voice is not None:
kwargs["voice"] = self.voice
if self.output_format is not None:
kwargs["output_format"] = self.output_format

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
audio = self.client.generate(
text=". ".join(prompts),
voice=self.voice,
model=self.model,
output_format=self.output_format,
text=prompt,
**kwargs,
)

content = b""
Expand Down
14 changes: 11 additions & 3 deletions griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,32 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver):
metadata={"serializable": True},
)
format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True})
speed: float = field(default=1.0, kw_only=True, metadata={"serializable": True})
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True)
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
max_characters: int = field(default=4096, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
),
)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
@speed.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_speed(self, attribute: str, value: float) -> None:
if value < 0.25 or value > 4.0:
raise ValueError("Speed must be between 0.5 and 4.0")

def try_text_to_audio(self, prompt: str) -> AudioArtifact:
response = self.client.audio.speech.create(
input=". ".join(prompts),
voice=self.voice,
input=prompt,
model=self.model,
voice=self.voice,
response_format=self.format,
speed=self.speed,
)

return AudioArtifact(value=response.content, format=self.format)
4 changes: 2 additions & 2 deletions griptape/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from .image.inpainting_image_generation_engine import InpaintingImageGenerationEngine
from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine
from .image_query.image_query_engine import ImageQueryEngine
from .audio.text_to_speech_engine import TextToSpeechEngine
from .audio.audio_transcription_engine import AudioTranscriptionEngine
from .text_to_speech.text_to_speech_engine import TextToSpeechEngine
from .audio_transcription.audio_transcription_engine import AudioTranscriptionEngine

__all__ = [
"BaseSummaryEngine",
Expand Down
21 changes: 0 additions & 21 deletions griptape/engines/audio/text_to_speech_engine.py

This file was deleted.

File renamed without changes.
Empty file.
27 changes: 27 additions & 0 deletions griptape/engines/text_to_speech/text_to_speech_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.configs import Defaults

if TYPE_CHECKING:
from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver


@define
class TextToSpeechEngine:
text_to_speech_driver: BaseTextToSpeechDriver = field(
default=Factory(lambda: Defaults.drivers_config.text_to_speech_driver), kw_only=True
)
prompt_joiner: str = field(default=". ", kw_only=True)

def run(self, prompts: list[str], *args, **kwargs) -> list[AudioArtifact]:
prompt_str = self.prompt_joiner.join(prompts).strip()
new_prompts = [
prompt_str[i : i + self.text_to_speech_driver.max_characters]
for i in range(0, len(prompt_str), self.text_to_speech_driver.max_characters)
]
return [self.text_to_speech_driver.try_text_to_audio(prompt=prompt) for prompt in new_prompts]
2 changes: 1 addition & 1 deletion griptape/events/start_text_to_speech_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

@define
class StartTextToSpeechEvent(BaseTextToSpeechEvent):
prompts: list[str] = field(kw_only=True, metadata={"serializable": True})
prompt: str = field(kw_only=True, metadata={"serializable": True})
11 changes: 6 additions & 5 deletions griptape/tasks/text_to_speech_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.artifacts import ListArtifact, TextArtifact
from griptape.engines import TextToSpeechEngine
from griptape.tasks.base_audio_generation_task import BaseAudioGenerationTask
from griptape.utils import J2
Expand Down Expand Up @@ -34,10 +34,11 @@ def input(self) -> TextArtifact:
def input(self, value: TextArtifact) -> None:
self._input = value

def run(self) -> AudioArtifact:
audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets)
def run(self) -> ListArtifact[AudioArtifact]:
audio_artifacts = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets)

if self.output_dir or self.output_file:
self._write_to_file(audio_artifact)
for audio_artifact in audio_artifacts:
self._write_to_file(audio_artifact)

return audio_artifact
return ListArtifact(audio_artifacts)
12 changes: 7 additions & 5 deletions griptape/tools/text_to_speech/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from attrs import define, field
from schema import Literal, Schema

from griptape.artifacts import ListArtifact
from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin
from griptape.tools import BaseTool
from griptape.utils.decorators import activity

if TYPE_CHECKING:
from griptape.artifacts import AudioArtifact, ErrorArtifact
from griptape.artifacts import AudioArtifact
from griptape.engines import TextToSpeechEngine


Expand All @@ -32,12 +33,13 @@ class TextToSpeechTool(ArtifactFileOutputMixin, BaseTool):
"schema": Schema({Literal("text", description="The literal text to be converted to speech."): str}),
},
)
def text_to_speech(self, params: dict[str, Any]) -> AudioArtifact | ErrorArtifact:
def text_to_speech(self, params: dict[str, Any]) -> ListArtifact[AudioArtifact]:
text = params["values"]["text"]

output_artifact = self.engine.run(prompts=[text])
output_artifacts = self.engine.run(prompts=[text])

if self.output_dir or self.output_file:
self._write_to_file(output_artifact)
for audio_artifact in output_artifacts:
self._write_to_file(audio_artifact)

return output_artifact
return ListArtifact(output_artifacts)
3 changes: 2 additions & 1 deletion tests/mocks/mock_text_to_speech_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class MockTextToSpeechDriver(BaseTextToSpeechDriver):
model: str = field(default="test-model", kw_only=True)
mock_output: str = field(default="mock output", kw_only=True)
max_characters: int = field(default=100, kw_only=True)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
def try_text_to_audio(self, prompt: str) -> AudioArtifact:
return AudioArtifact(value=self.mock_output, format="mp3")
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def test_to_dict(self, config):
"organization": None,
"type": "AzureOpenAiTextToSpeechDriver",
"voice": "alloy",
"speed": 1.0,
"max_characters": 4096,
},
"audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"},
}
2 changes: 2 additions & 0 deletions tests/unit/configs/drivers/test_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def test_to_dict(self, config):
"model": "tts-1",
"organization": None,
"voice": "alloy",
"speed": 1.0,
"max_characters": 4096,
},
"audio_transcription_driver": {
"type": "OpenAiAudioTranscriptionDriver",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,21 @@ def test_init(self):

def test_run_text_to_audio(self, mock_speech_create):
driver = AzureOpenAiTextToSpeechDriver(azure_endpoint="foobar")
output = driver.run_text_to_audio(["foo", "bar"])
mock_speech_create.assert_called_once_with(
input="foo. bar",
output1 = driver.run_text_to_audio("foo")
mock_speech_create.assert_called_with(
input="foo",
model=driver.model,
response_format=driver.format,
voice=driver.voice,
speed=driver.speed,
)
assert output.value == b"speech"
output2 = driver.run_text_to_audio("bar")
mock_speech_create.assert_called_with(
input="bar",
model=driver.model,
response_format=driver.format,
voice=driver.voice,
speed=driver.speed,
)
assert output1.value == b"speech"
assert output2.value == b"speech"
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_init(self, driver):
def test_try_text_to_audio(self, driver):
driver.client.generate.return_value = [b"audio data"]

audio_artifact = driver.try_text_to_audio(prompts=["test prompt"])
audio_artifact = driver.try_text_to_audio(prompt="test prompt")

assert audio_artifact.value == b"audio data"
assert audio_artifact.format == "mp3"
11 changes: 6 additions & 5 deletions tests/unit/tasks/test_text_to_speech_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest.mock import Mock

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.artifacts import AudioArtifact, ListArtifact, TextArtifact
from griptape.engines import TextToSpeechEngine
from griptape.structures import Agent, Pipeline
from griptape.tasks import BaseTask, TextToSpeechTask
Expand Down Expand Up @@ -30,16 +30,17 @@ def test_config_text_to_speech_engine(self):

def test_calls(self):
text_to_speech_engine = Mock()
text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3")
text_to_speech_engine.run.return_value = ListArtifact([AudioArtifact(b"audio content", format="mp3")])

assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run().value == b"audio content"
assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run()[0].value == b"audio content"

def test_run(self):
text_to_speech_engine = Mock()
text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3")
text_to_speech_engine.run.return_value = ListArtifact([AudioArtifact(b"audio content", format="mp3")])

task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine)
pipeline = Pipeline()
pipeline.add_task(task)

assert isinstance(pipeline.run().output, AudioArtifact)
assert isinstance(pipeline.run().output, ListArtifact)
assert isinstance(pipeline.run().output[0], AudioArtifact) # pyright: ignore[reportIndexIssue, reportOptionalSubscript]
10 changes: 5 additions & 5 deletions tests/unit/tools/test_text_to_speech_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ def test_validate_output_configs(self, text_to_speech_engine) -> None:
def test_text_to_speech(self, text_to_speech_client) -> None:
text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3")

audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}})
list_audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}})

assert audio_artifact
assert list_audio_artifact

def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None:
outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3"
text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile)

text_to_speech_client.engine.run.return_value = AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess]
text_to_speech_client.engine.run.return_value = [AudioArtifact(value=b"audio data", format="mp3")] # pyright: ignore[reportFunctionMemberAccess]

audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}})
list_audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}})

assert audio_artifact
assert list_audio_artifact
assert os.path.exists(outfile)

0 comments on commit cd8fd98

Please sign in to comment.