Skip to content

Commit

Permalink
change synthesizer factory api to take in aiohttp session (#318)
Browse files: browse the repository at this point in the history
* change synthesizer factory api

* update synthesizer constructors
  • Loading branch information
ajar98 authored Jul 30, 2023
1 parent 4ce40fc commit 8fb11e5
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 18 deletions.
2 changes: 1 addition & 1 deletion vocode/streaming/agent/restful_user_implemented_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def respond(
) -> Tuple[Optional[str], bool]:
config = self.agent_config.respond
try:
# todo: cache session
# TODO: cache session
async with aiohttp.ClientSession() as session:
payload = RESTfulAgentInput(
human_input=human_input, conversation_id=conversation_id
Expand Down
4 changes: 3 additions & 1 deletion vocode/streaming/synthesizer/azure_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from typing import Any, List, Optional, Tuple
from xml.etree import ElementTree
import aiohttp
from vocode import getenv
from opentelemetry.context.context import Context

Expand Down Expand Up @@ -62,8 +63,9 @@ def __init__(
logger: Optional[logging.Logger] = None,
azure_speech_key: Optional[str] = None,
azure_speech_region: Optional[str] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)
# Instantiates a client
azure_speech_key = azure_speech_key or getenv("AZURE_SPEECH_KEY")
azure_speech_region = azure_speech_region or getenv("AZURE_SPEECH_REGION")
Expand Down
4 changes: 3 additions & 1 deletion vocode/streaming/synthesizer/coqui_tts_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor
import logging
from typing import Optional
import aiohttp
from pydub import AudioSegment
import numpy as np
import io
Expand All @@ -28,8 +29,9 @@ def __init__(
self,
synthesizer_config: CoquiTTSSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)

from TTS.api import TTS

Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/synthesizer/eleven_labs_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __init__(
self,
synthesizer_config: ElevenLabsSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)

import elevenlabs

Expand Down
34 changes: 26 additions & 8 deletions vocode/streaming/synthesizer/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Optional
import typing
import aiohttp

from vocode.streaming.models.synthesizer import (
AzureSynthesizerConfig,
Expand Down Expand Up @@ -31,22 +32,39 @@ def create_synthesizer(
self,
synthesizer_config: SynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
if isinstance(synthesizer_config, GoogleSynthesizerConfig):
return GoogleSynthesizer(synthesizer_config, logger=logger)
return GoogleSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, AzureSynthesizerConfig):
return AzureSynthesizer(synthesizer_config, logger=logger)
return AzureSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, ElevenLabsSynthesizerConfig):
return ElevenLabsSynthesizer(synthesizer_config, logger=logger)
return ElevenLabsSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, PlayHtSynthesizerConfig):
return PlayHtSynthesizer(synthesizer_config, logger=logger)
return PlayHtSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, RimeSynthesizerConfig):
return RimeSynthesizer(synthesizer_config, logger=logger)
return RimeSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, GTTSSynthesizerConfig):
return GTTSSynthesizer(synthesizer_config, logger=logger)
return GTTSSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, StreamElementsSynthesizerConfig):
return StreamElementsSynthesizer(synthesizer_config, logger=logger)
return StreamElementsSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
elif isinstance(synthesizer_config, CoquiTTSSynthesizerConfig):
return CoquiTTSSynthesizer(synthesizer_config, logger=logger)
return CoquiTTSSynthesizer(
synthesizer_config, logger=logger, aiohttp_session=aiohttp_session
)
else:
raise Exception("Invalid synthesizer config")
5 changes: 3 additions & 2 deletions vocode/streaming/synthesizer/google_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import wave
from typing import Any, Optional
import aiohttp

from vocode import getenv

Expand All @@ -24,13 +25,13 @@


class GoogleSynthesizer(BaseSynthesizer[GoogleSynthesizerConfig]):

def __init__(
self,
synthesizer_config: GoogleSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)

from google.cloud import texttospeech_v1beta1 as tts
import google.auth
Expand Down
4 changes: 3 additions & 1 deletion vocode/streaming/synthesizer/gtts_synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import aiohttp
from pydub import AudioSegment
from typing import Optional
from io import BytesIO
Expand All @@ -21,8 +22,9 @@ def __init__(
self,
synthesizer_config: GTTSSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)

from gtts import gTTS

Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/synthesizer/play_ht_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def __init__(
api_key: Optional[str] = None,
user_id: Optional[str] = None,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)
self.synthesizer_config = synthesizer_config
self.api_key = api_key or getenv("PLAY_HT_API_KEY")
self.user_id = user_id or getenv("PLAY_HT_USER_ID")
Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/synthesizer/rime_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def __init__(
self,
synthesizer_config: RimeSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)
self.api_key = getenv("RIME_API_KEY")
self.speaker = synthesizer_config.speaker
self.sampling_rate = synthesizer_config.sampling_rate
Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/synthesizer/stream_elements_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def __init__(
self,
synthesizer_config: StreamElementsSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(synthesizer_config)
super().__init__(synthesizer_config, aiohttp_session)
self.voice = synthesizer_config.voice

async def create_speech(
Expand Down

0 comments on commit 8fb11e5

Please sign in to comment.