Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`apply_text_normalization` was incorrectly set as a query parameter. It's now
being added as a request parameter.

- Updated all STT and TTS services to use a consistent error handling pattern
based on the `push_error()` method, for better integration with pipeline error events.

- Fixed an issue where `RimeHttpTTSService` and `PiperTTSService` could generate
audio frames that were not correctly 16-bit aligned, potentially leading to
internal errors or static audio.
Expand Down
19 changes: 13 additions & 6 deletions src/pipecat/services/assemblyai/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
Expand Down Expand Up @@ -205,8 +206,9 @@ async def _connect(self):

await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"Failed to connect to AssemblyAI: {e}")
logger.error(f"{self} exception: {e}")
self._connected = False
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
raise

async def _disconnect(self):
Expand All @@ -231,15 +233,17 @@ async def _disconnect(self):
logger.warning("Timed out waiting for termination message from server")

except Exception as e:
logger.warning(f"Error during termination handshake: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))

if self._receive_task:
await self.cancel_task(self._receive_task)

await self._websocket.close()

except Exception as e:
logger.error(f"Error during disconnect: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))

finally:
self._websocket = None
Expand All @@ -258,11 +262,13 @@ async def _receive_task_handler(self):
except websockets.exceptions.ConnectionClosedOK:
break
except Exception as e:
logger.error(f"Error processing WebSocket message: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
break

except Exception as e:
logger.error(f"Fatal error in receive handler: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))

def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
"""Parse a raw message into the appropriate message type."""
Expand Down Expand Up @@ -291,7 +297,8 @@ async def _handle_message(self, message: Dict[str, Any]):
elif isinstance(parsed_message, TerminationMessage):
await self._handle_termination(parsed_message)
except Exception as e:
logger.error(f"Error handling message: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))

async def _handle_termination(self, message: TerminationMessage):
"""Handle termination message."""
Expand Down
16 changes: 10 additions & 6 deletions src/pipecat/services/asyncai/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ async def _connect_websocket(self):

await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} initialization error: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
self._websocket = None
await self._call_event_handler("on_connection_error", f"{e}")

Expand All @@ -250,7 +251,8 @@ async def _disconnect_websocket(self):
logger.debug("Disconnecting from Async")
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
self._websocket = None
self._started = False
Expand Down Expand Up @@ -298,7 +300,7 @@ async def _receive_messages(self):
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
await self.push_error(ErrorFrame(error=f"{self} error: {msg['message']}"))
else:
logger.error(f"{self} error, unknown message type: {msg}")

Expand Down Expand Up @@ -343,14 +345,16 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} error sending message: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))


class AsyncAIHttpTTSService(TTSService):
Expand Down Expand Up @@ -484,7 +488,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
if response.status != 200:
error_text = await response.text()
logger.error(f"Async API error: {error_text}")
await self.push_error(ErrorFrame(f"Async API error: {error_text}"))
await self.push_error(ErrorFrame(error=f"Async API error: {error_text}"))
raise Exception(f"Async API returned status {response.status}: {error_text}")

audio_data = await response.read()
Expand All @@ -501,7 +505,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:

except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
await self.stop_ttfb_metrics()
yield TTSStoppedFrame()
28 changes: 15 additions & 13 deletions src/pipecat/services/aws/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ async def start(self, frame: StartFrame):
return
logger.warning("WebSocket connection not established after connect")
except Exception as e:
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
retry_count += 1
if retry_count < max_retries:
await asyncio.sleep(1) # Wait before retrying
Expand Down Expand Up @@ -181,8 +182,8 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
try:
await self._connect()
except Exception as e:
logger.error(f"Failed to reconnect: {e}")
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
return

# Format the audio data according to AWS event stream format
Expand All @@ -199,13 +200,13 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self._disconnect()
# Don't yield error here - we'll retry on next frame
except Exception as e:
logger.error(f"Error sending audio: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
await self._disconnect()

except Exception as e:
logger.error(f"Error in run_stt: {e}")
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
await self._disconnect()

async def _connect(self):
Expand Down Expand Up @@ -288,7 +289,8 @@ async def _connect(self):

await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
await self._disconnect()
raise

Expand All @@ -308,7 +310,8 @@ async def _disconnect(self):
await self._ws_client.send(json.dumps(end_stream))
await self._ws_client.close()
except Exception as e:
logger.warning(f"{self} Error closing WebSocket connection: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
self._ws_client = None
await self._call_event_handler("on_disconnected")
Expand Down Expand Up @@ -527,15 +530,14 @@ async def _receive_loop(self):
elif headers.get(":message-type") == "exception":
error_msg = payload.get("Message", "Unknown error")
logger.error(f"{self} Exception from AWS: {error_msg}")
await self.push_frame(
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
)
await self.push_frame(ErrorFrame(f"AWS Transcribe error: {error_msg}"))
else:
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
break
except Exception as e:
logger.error(f"{self} Unexpected error in receive loop: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
break
39 changes: 24 additions & 15 deletions src/pipecat/services/azure/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
Expand Down Expand Up @@ -111,13 +112,17 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
audio: Raw audio bytes to process.

Yields:
None - actual transcription frames are pushed via callbacks.
Frame: Either None for successful processing or ErrorFrame on failure.
"""
await self.start_processing_metrics()
await self.start_ttfb_metrics()
if self._audio_stream:
self._audio_stream.write(audio)
yield None
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
if self._audio_stream:
self._audio_stream.write(audio)
yield None
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")

async def start(self, frame: StartFrame):
"""Start the speech recognition service.
Expand All @@ -133,17 +138,21 @@ async def start(self, frame: StartFrame):
if self._audio_stream:
return

stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
self._audio_stream = PushAudioInputStream(stream_format)
try:
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
self._audio_stream = PushAudioInputStream(stream_format)

audio_config = AudioConfig(stream=self._audio_stream)
audio_config = AudioConfig(stream=self._audio_stream)

self._speech_recognizer = SpeechRecognizer(
speech_config=self._speech_config, audio_config=audio_config
)
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
self._speech_recognizer.start_continuous_recognition_async()
self._speech_recognizer = SpeechRecognizer(
speech_config=self._speech_config, audio_config=audio_config
)
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
self._speech_recognizer.start_continuous_recognition_async()
except Exception as e:
logger.error(f"{self} exception during initialization: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))

async def stop(self, frame: EndFrame):
"""Stop the speech recognition service.
Expand Down
7 changes: 5 additions & 2 deletions src/pipecat/services/azure/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
if self._speech_synthesizer is None:
error_msg = "Speech synthesizer not initialized."
logger.error(error_msg)
yield ErrorFrame(error_msg)
yield ErrorFrame(error=error_msg)
return

try:
Expand All @@ -355,13 +355,15 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
yield TTSStoppedFrame()

except Exception as e:
logger.error(f"{self} error during synthesis: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
yield TTSStoppedFrame()
# Could add reconnection logic here if needed
return

except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))


class AzureHttpTTSService(AzureBaseTTSService):
Expand Down Expand Up @@ -439,3 +441,4 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
if cancellation_details.reason == CancellationReason.Error:
logger.error(f"{self} error: {cancellation_details.error_details}")
yield ErrorFrame(error=f"{self} error: {cancellation_details.error_details}")
9 changes: 7 additions & 2 deletions src/pipecat/services/cartesia/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
Expand Down Expand Up @@ -275,7 +276,8 @@ async def _connect_websocket(self):
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
await self._call_event_handler("on_connected")
except Exception as e:
logger.error(f"{self}: unable to connect to Cartesia: {e}")
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))

async def _disconnect_websocket(self):
try:
Expand All @@ -284,6 +286,7 @@ async def _disconnect_websocket(self):
await self._websocket.close()
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
self._websocket = None
await self._call_event_handler("on_disconnected")
Expand Down Expand Up @@ -315,7 +318,9 @@ async def _process_response(self, data):
await self._on_transcript(data)

elif data["type"] == "error":
logger.error(f"Cartesia error: {data.get('message', 'Unknown error')}")
error_msg = data.get("message", "Unknown error")
logger.error(f"Cartesia error: {error_msg}")
await self.push_error(ErrorFrame(error=error_msg))

@traced_stt
async def _handle_transcription(
Expand Down
Loading