Add Whisper STT service using OpenAI API #1128

Merged · 4 commits · Feb 8, 2025
@@ -18,7 +18,7 @@
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService, OpenAITTSService
from pipecat.services.openai import OpenAILLMService, OpenAISTTService, OpenAITTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport

load_dotenv(override=True)
@@ -37,12 +37,22 @@ async def main():
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
audio_out_sample_rate=24000,
transcription_enabled=False,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)

# You can use the OpenAI compatible API like Groq.
# stt = OpenAISTTService(
# base_url="https://api.groq.com/openai/v1",
# api_key="gsk_***",
# model="whisper-large-v3",
# )
stt = OpenAISTTService(api_key=os.getenv("OPENAI_API_KEY"), model="whisper-1")

tts = OpenAITTSService(api_key=os.getenv("OPENAI_API_KEY"), voice="alloy")

llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
@@ -60,6 +70,7 @@ async def main():
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # STT
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
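Note on the transport change above (explanatory, not part of the diff): Daily's built-in transcription is disabled and VAD audio passthrough is enabled so that raw audio frames reach the in-pipeline STT service, and the 24000 Hz output rate matches the 24 kHz audio OpenAI TTS returns. The same parameters, annotated:

DailyParams(
    audio_out_enabled=True,
    audio_out_sample_rate=24000,  # OpenAI TTS outputs 24 kHz audio
    transcription_enabled=False,  # Daily transcription off; STT now runs in-pipeline
    vad_enabled=True,
    vad_analyzer=SileroVADAnalyzer(),
    vad_audio_passthrough=True,  # forward raw audio frames downstream to the STT service
)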
src/pipecat/services/openai.py: 63 additions, 1 deletion
@@ -30,6 +30,7 @@
OpenAILLMContextAssistantTimestampFrame,
StartFrame,
StartInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -48,7 +49,12 @@
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import ImageGenService, LLMService, TTSService
from pipecat.services.ai_services import (
ImageGenService,
LLMService,
SegmentedSTTService,
TTSService,
)
from pipecat.utils.time import time_now_iso8601

try:
@@ -59,6 +65,7 @@
BadRequestError,
DefaultAsyncHttpxClient,
)
from openai.types.audio import Transcription
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -391,6 +398,61 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
yield frame


class OpenAISTTService(SegmentedSTTService):
"""OpenAI Speech-to-Text (STT) service.

This service uses OpenAI's Whisper API to convert audio to text.

Args:
model: Whisper model to use. Defaults to "whisper-1".
api_key: OpenAI API key. Defaults to None.
base_url: API base URL. Defaults to None.
**kwargs: Additional arguments passed to SegmentedSTTService.
"""

def __init__(
self,
*,
model: str = "whisper-1",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self.set_model_name(model)
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

async def set_model(self, model: str):
self.set_model_name(model)

def can_generate_metrics(self) -> bool:
return True

async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()

response: Transcription = await self._client.audio.transcriptions.create(
file=("audio.wav", audio, "audio/wav"), model=self.model_name
Contributor:
Does this save a file to disk? We should not add that by default; probably an argument to the constructor, like save_audio_file: Optional[str] = None.

Contributor:
Ooops, duh! Sorry, I misread the code!
)

await self.stop_ttfb_metrics()
await self.stop_processing_metrics()

text = response.text.strip()

if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(text, "", time_now_iso8601())
else:
logger.warning("Received empty transcription from API")

except Exception as e:
logger.exception(f"Exception during transcription: {e}")
yield ErrorFrame(f"Error during transcription: {str(e)}")


class OpenAITTSService(TTSService):
"""OpenAI Text-to-Speech service that generates audio from text.

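For completeness, a minimal construction sketch of the new service (not taken from the PR; the GROQ_API_KEY variable name is an assumption). Inside a pipeline, SegmentedSTTService buffers VAD-segmented speech and hands each segment to run_stt as WAV bytes, as I read the base class:

import os

from pipecat.services.openai import OpenAISTTService

# Default: OpenAI-hosted Whisper.
stt = OpenAISTTService(api_key=os.getenv("OPENAI_API_KEY"), model="whisper-1")

# Any OpenAI-compatible endpoint works the same way, e.g. the Groq
# deployment from the commented-out example above.
groq_stt = OpenAISTTService(
    base_url="https://api.groq.com/openai/v1",
    api_key=os.getenv("GROQ_API_KEY"),  # assumed env var name
    model="whisper-large-v3",
)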