Merge pull request #1128 from golbin/whisper-api
Add Whisper STT service using OpenAI API
markbackman authored Feb 8, 2025
2 parents 0180619 + 5989e1e commit d678619
Showing 2 changed files with 76 additions and 3 deletions.
@@ -18,7 +18,7 @@
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService, OpenAITTSService
from pipecat.services.openai import OpenAILLMService, OpenAISTTService, OpenAITTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport

load_dotenv(override=True)
@@ -37,12 +37,22 @@ async def main():
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
audio_out_sample_rate=24000,
transcription_enabled=False,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)

# You can also use an OpenAI-compatible API such as Groq.
# stt = OpenAISTTService(
# base_url="https://api.groq.com/openai/v1",
# api_key="gsk_***",
# model="whisper-large-v3",
# )
stt = OpenAISTTService(api_key=os.getenv("OPENAI_API_KEY"), model="whisper-1")

tts = OpenAITTSService(api_key=os.getenv("OPENAI_API_KEY"), voice="alloy")

llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
@@ -60,6 +70,7 @@ async def main():
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # STT
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
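For orientation, a minimal sketch of how the new service is constructed, following the example above. This is not part of the diff; the GROQ_API_KEY environment variable is a placeholder for the masked key shown in the commented-out Groq variant.

import os

from pipecat.services.openai import OpenAISTTService

# Default: OpenAI's hosted Whisper API, matching the example above.
stt = OpenAISTTService(api_key=os.getenv("OPENAI_API_KEY"), model="whisper-1")

# Any OpenAI-compatible endpoint also works, e.g. Groq's hosted Whisper models.
# GROQ_API_KEY is a placeholder name used here for illustration.
stt_groq = OpenAISTTService(
    base_url="https://api.groq.com/openai/v1",
    api_key=os.getenv("GROQ_API_KEY"),
    model="whisper-large-v3",
)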
src/pipecat/services/openai.py: 63 additions and 1 deletion
@@ -30,6 +30,7 @@
OpenAILLMContextAssistantTimestampFrame,
StartFrame,
StartInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -48,7 +49,12 @@
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import ImageGenService, LLMService, TTSService
from pipecat.services.ai_services import (
ImageGenService,
LLMService,
SegmentedSTTService,
TTSService,
)
from pipecat.utils.time import time_now_iso8601

try:
@@ -59,6 +65,7 @@
BadRequestError,
DefaultAsyncHttpxClient,
)
from openai.types.audio import Transcription
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -391,6 +398,61 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
yield frame


class OpenAISTTService(SegmentedSTTService):
"""OpenAI Speech-to-Text (STT) service.
This service uses OpenAI's Whisper API to convert audio to text.
Args:
model: Whisper model to use. Defaults to "whisper-1".
api_key: OpenAI API key. Defaults to None.
base_url: API base URL. Defaults to None.
**kwargs: Additional arguments passed to SegmentedSTTService.
"""

def __init__(
self,
*,
model: str = "whisper-1",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self.set_model_name(model)
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

async def set_model(self, model: str):
self.set_model_name(model)

def can_generate_metrics(self) -> bool:
return True

async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()

response: Transcription = await self._client.audio.transcriptions.create(
file=("audio.wav", audio, "audio/wav"), model=self.model_name
)

await self.stop_ttfb_metrics()
await self.stop_processing_metrics()

text = response.text.strip()

if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(text, "", time_now_iso8601())
else:
logger.warning("Received empty transcription from API")

except Exception as e:
logger.exception(f"Exception during transcription: {e}")
yield ErrorFrame(f"Error during transcription: {str(e)}")


class OpenAITTSService(TTSService):
"""OpenAI Text-to-Speech service that generates audio from text.

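The new class subclasses SegmentedSTTService, which collects audio between VAD speech events and hands each complete segment to run_stt() as WAV bytes; run_stt() uploads the segment to the transcriptions endpoint and yields a single TranscriptionFrame (or an ErrorFrame on failure). Below is a standalone sketch of the same API call using the openai client directly, assuming OPENAI_API_KEY is set; the transcribe() helper and the segment.wav path are illustrative and not part of the PR.

import asyncio
import os

from openai import AsyncOpenAI


async def transcribe(wav_bytes: bytes) -> str:
    # Same call shape as OpenAISTTService.run_stt(): the audio is uploaded as an
    # in-memory file tuple of (filename, bytes, content type).
    client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = await client.audio.transcriptions.create(
        file=("audio.wav", wav_bytes, "audio/wav"),
        model="whisper-1",
    )
    return response.text.strip()


if __name__ == "__main__":
    # "segment.wav" is a placeholder path for a pre-recorded audio clip.
    with open("segment.wav", "rb") as f:
        print(asyncio.run(transcribe(f.read())))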
