Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds voice_engine param to all TTS. #31

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ dmypy.json
# pytype static type analyzer
.pytype/

.flake8


##---------------------------------------------------
# Windows default .gitignore
Expand Down
32 changes: 24 additions & 8 deletions pyht/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ async def refresh_lease(self):
async def stream_tts_input(
self,
text_stream: AsyncGenerator[str, None] | AsyncIterable[str],
options: TTSOptions
options: TTSOptions,
voice_engine: str | None = None,
):
"""Stream input to Play.ht via the text_stream object."""
buffer = io.StringIO()
Expand All @@ -88,18 +89,19 @@ async def stream_tts_input(
buffer.write(" ") # normalize word spacing.
if SENTENCE_END_REGEX.match(t) is None:
continue
async for data in self.tts(buffer.getvalue(), options):
async for data in self.tts(buffer.getvalue(), options, voice_engine):
yield data
buffer = io.StringIO()
# If text_stream closes, send all remaining text, regardless of sentence structure.
if buffer.tell() > 0:
async for data in self.tts(buffer.getvalue(), options):
async for data in self.tts(buffer.getvalue(), options, voice_engine):
yield data

async def tts(
self,
text: str | list[str],
options: TTSOptions,
voice_engine: str | None = None,
context: AsyncContext | None = None
) -> AsyncIterable[bytes]:
await self.refresh_lease()
Expand All @@ -114,12 +116,16 @@ async def tts(
text = ensure_sentence_end(text)

quality = options.quality.lower()
_quality = api_pb2.QUALITY_DRAFT

if voice_engine == "PlayHT2.0" and quality != "faster":
_quality = api_pb2.QUALITY_MEDIUM

params = api_pb2.TtsParams(
text=text,
voice=options.voice,
format=options.format,
quality=api_pb2.QUALITY_DRAFT if quality == "faster" else api_pb2.QUALITY_MEDIUM,
quality=_quality,
temperature=options.temperature,
top_p=options.top_p,
sample_rate=options.sample_rate,
Expand All @@ -133,15 +139,19 @@ async def tts(
async for response in stream:
yield response.data

def get_stream_pair(self, options: TTSOptions) -> tuple['_InputStream', '_OutputStream']:
def get_stream_pair(
self,
options: TTSOptions,
voice_engine: str | None = None
) -> tuple['_InputStream', '_OutputStream']:
"""Get a linked pair of (input, output) streams.

These stream objects ARE NOT thread-safe. Coroutines using these stream objects must
run on the same thread.
"""
shared_q = asyncio.Queue()
return (
_InputStream(self, options, shared_q),
_InputStream(self, options, shared_q, voice_engine),
_OutputStream(shared_q)
)

Expand Down Expand Up @@ -218,11 +228,17 @@ class _InputStream:
input_stream += 'Add another sentence to the stream.'
input_stream.done()
"""
def __init__(self, client: AsyncClient, options: TTSOptions, q: asyncio.Queue[bytes | None]):
def __init__(
self,
client: AsyncClient,
options: TTSOptions,
q: asyncio.Queue[bytes | None],
voice_engine: str | None,
):
self._input = TextStream()

async def listen():
async for output in client.stream_tts_input(self._input, options):
async for output in client.stream_tts_input(self._input, options, voice_engine):
await q.put(output)
await q.put(None)

Expand Down
34 changes: 24 additions & 10 deletions pyht/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
user_id: str,
api_key: str,
auto_connect: bool = True,
advanced: "Client.AdvancedOptions" | None = None,
advanced: "Client.AdvancedOptions | None" = None,
):
assert user_id, "user_id is required"
assert api_key, "api_key is required"
Expand Down Expand Up @@ -94,7 +94,8 @@ def refresh_lease(self):
def stream_tts_input(
self,
text_stream: Generator[str, None, None] | Iterable[str],
options: TTSOptions
options: TTSOptions,
voice_engine: str | None = None
) -> Iterable[bytes]:
"""Stream input to Play.ht via the text_stream object."""
buffer = io.StringIO()
Expand All @@ -104,13 +105,18 @@ def stream_tts_input(
buffer.write(" ") # normalize word spacing.
if SENTENCE_END_REGEX.match(t) is None:
continue
yield from self.tts(buffer.getvalue(), options)
yield from self.tts(buffer.getvalue(), options, voice_engine)
buffer = io.StringIO()
# If text_stream closes, send all remaining text, regardless of sentence structure.
if buffer.tell() > 0:
yield from self.tts(buffer.getvalue(), options)
yield from self.tts(buffer.getvalue(), options, voice_engine)

def tts(self, text: str | List[str], options: TTSOptions) -> Iterable[bytes]:
def tts(
self,
text: str | List[str],
options: TTSOptions,
voice_engine: str | None = None
) -> Iterable[bytes]:
self.refresh_lease()
with self._lock:
assert self._lease is not None and self._rpc is not None, "No connection"
Expand All @@ -123,12 +129,16 @@ def tts(self, text: str | List[str], options: TTSOptions) -> Iterable[bytes]:
text = ensure_sentence_end(text)

quality = options.quality.lower()
_quality = api_pb2.QUALITY_DRAFT

if voice_engine == "PlayHT2.0" and quality != "faster":
_quality = api_pb2.QUALITY_MEDIUM

params = api_pb2.TtsParams(
text=text,
voice=options.voice,
format=options.format,
quality=api_pb2.QUALITY_DRAFT if quality == "faster" else api_pb2.QUALITY_MEDIUM,
quality=_quality,
temperature=options.temperature,
top_p=options.top_p,
sample_rate=options.sample_rate,
Expand All @@ -140,14 +150,18 @@ def tts(self, text: str | List[str], options: TTSOptions) -> Iterable[bytes]:
for item in response:
yield item.data

def get_stream_pair(self, options: TTSOptions) -> Tuple['_InputStream', '_OutputStream']:
def get_stream_pair(
self,
options: TTSOptions,
voice_engine: str | None = None
) -> Tuple['_InputStream', '_OutputStream']:
"""Get a linked pair of (input, output) streams.

These stream objects are thread-aware and safe to use in separate threads.
"""
shared_q = queue.Queue()
return (
_InputStream(self, options, shared_q),
_InputStream(self, options, shared_q, voice_engine),
_OutputStream(shared_q)
)

Expand Down Expand Up @@ -193,11 +207,11 @@ class _InputStream:
input_stream += 'Add another sentence to the stream.'
input_stream.done()
"""
def __init__(self, client: Client, options: TTSOptions, q: queue.Queue[bytes | None]):
def __init__(self, client: Client, options: TTSOptions, q: queue.Queue[bytes | None], voice_engine: str | None):
self._input = TextStream()

def listen():
for output in client.stream_tts_input(self._input, options):
for output in client.stream_tts_input(self._input, options, voice_engine):
q.put(output)
q.put(None)

Expand Down