-
Notifications
You must be signed in to change notification settings - Fork 475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
upgrade to latest cartesia 1.0.3 #587
Changes from 2 commits
bfbd1c8
5fcb20a
abde468
c2c68e6
9b5d5ca
c91e2aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,31 +19,39 @@ def __init__( | |
|
||
# Lazy import the cartesia module | ||
try: | ||
from cartesia.tts import AsyncCartesiaTTS | ||
from cartesia import AsyncCartesia | ||
except ImportError as e: | ||
raise ImportError( | ||
f"Missing required dependancies for CartesiaSynthesizer" | ||
) from e | ||
|
||
self.cartesia_tts = AsyncCartesiaTTS | ||
|
||
self.api_key = synthesizer_config.api_key or getenv("CARTESIA_API_KEY") | ||
if not self.api_key: | ||
raise ValueError("Missing Cartesia API key") | ||
|
||
self.cartesia_tts = AsyncCartesia | ||
|
||
if synthesizer_config.audio_encoding == AudioEncoding.LINEAR16: | ||
self.channel_width = 2 | ||
match synthesizer_config.sampling_rate: | ||
case SamplingRate.RATE_44100: | ||
self.sampling_rate = 44100 | ||
self.output_format = "pcm_44100" | ||
self.output_format = { | ||
"sample_rate": 44100, | ||
"encoding": "pcm_s16le", | ||
"container": "raw", | ||
} | ||
case SamplingRate.RATE_22050: | ||
self.sampling_rate = 22050 | ||
self.output_format = "pcm_22050" | ||
self.output_format = { | ||
"sample_rate": 22050, | ||
"encoding": "pcm_s16le", | ||
"container": "raw", | ||
} | ||
case SamplingRate.RATE_16000: | ||
self.sampling_rate = 16000 | ||
self.output_format = "pcm_16000" | ||
self.output_format = { | ||
"sample_rate": 16000, | ||
"encoding": "pcm_s16le", | ||
"container": "raw", | ||
} | ||
case _: | ||
raise ValueError( | ||
f"Unsupported PCM sampling rate {synthesizer_config.sampling_rate}" | ||
|
@@ -52,41 +60,44 @@ def __init__( | |
# Cartesia has issues with MuLaw/8000. Use pcm/16000 and | ||
# create_synthesis_result_from_wav will handle the conversion to mulaw/8000 | ||
self.channel_width = 2 | ||
self.output_format = "pcm_16000" | ||
self.sampling_rate = 16000 | ||
self.output_format = { | ||
"sample_rate": 16000, | ||
"encoding": "pcm_s16le", | ||
"container": "raw", | ||
} | ||
else: | ||
raise ValueError( | ||
f"Unsupported audio encoding {synthesizer_config.audio_encoding}" | ||
) | ||
|
||
if not isinstance(self.output_format["sample_rate"], int): | ||
raise ValueError(f"Invalid type for sample_rate") | ||
self.sampling_rate = self.output_format["sample_rate"] | ||
self.num_channels = 1 | ||
self.model_id = synthesizer_config.model_id | ||
self.voice_id = synthesizer_config.voice_id | ||
self.client = self.cartesia_tts(api_key=self.api_key) | ||
self.voice_embedding = self.client.get_voice_embedding(voice_id=self.voice_id) | ||
|
||
|
||
async def create_speech_uncached( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if interested, would love to support you making this properly streaming!! would be something like:
you would have to also make the synthesizer support mulaw properly (which I think they fixed recently!) |
||
self, | ||
message: BaseMessage, | ||
chunk_size: int, | ||
is_first_text_chunk: bool = False, | ||
is_sole_text_chunk: bool = False, | ||
) -> SynthesisResult: | ||
generator = await self.client.generate( | ||
generator = await self.client.tts.sse( | ||
model_id=self.model_id, | ||
transcript=message.text, | ||
voice=self.voice_embedding, | ||
voice_id=self.voice_id, | ||
stream=True, | ||
model_id=self.model_id, | ||
data_rtype='bytes', | ||
output_format=self.output_format | ||
) | ||
|
||
audio_file = io.BytesIO() | ||
with wave.open(audio_file, 'wb') as wav_file: | ||
wav_file.setnchannels(self.num_channels) | ||
wav_file.setsampwidth(self.channel_width) | ||
wav_file.setframerate(self.sampling_rate) | ||
wav_file.setframerate(float(self.sampling_rate)) | ||
async for chunk in generator: | ||
wav_file.writeframes(chunk['audio']) | ||
audio_file.seek(0) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this go back to being optional?