Skip to content

Commit

Permalink
Streaming callbacks (#85)
Browse files Browse the repository at this point in the history
* add warning for callback_method for streaming

* add callback parameter to WebsocketBuilder

* fix clippy

* add method on websocket to get request id

* documentation and error fix
  • Loading branch information
bd-g authored Aug 22, 2024
1 parent 896359b commit 5a1420c
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 2 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ name = "simple_stream"
path = "examples/transcription/websocket/simple_stream.rs"
required-features = ["listen"]

[[example]]
name = "callback_stream"
path = "examples/transcription/websocket/callback_stream.rs"
required-features = ["listen"]

[[example]]
name = "microphone_stream"
path = "examples/transcription/websocket/microphone_stream.rs"
Expand Down
54 changes: 54 additions & 0 deletions examples/transcription/websocket/callback_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use std::env;
use std::time::Duration;

use futures::stream::StreamExt;

use deepgram::{
common::options::{Encoding, Endpointing, Language, Options},
Deepgram, DeepgramError,
};

static PATH_TO_FILE: &str = "examples/audio/bueller.wav";
static AUDIO_CHUNK_SIZE: usize = 3174;
static FRAME_DELAY: Duration = Duration::from_millis(16);

#[tokio::main]
async fn main() -> Result<(), DeepgramError> {
let deepgram_api_key =
env::var("DEEPGRAM_API_KEY").expect("DEEPGRAM_API_KEY environmental variable");

let dg_client = Deepgram::new(&deepgram_api_key)?;

let options = Options::builder()
.smart_format(true)
.language(Language::en_US)
.build();

let callback_url = env::var("DEEPGRAM_CALLBACK_URL")
.expect("DEEPGRAM_CALLBACK_URL environmental variable")
.parse()
.expect("DEEPGRAM_CALLBACK_URL not a valid URL");

let mut results = dg_client
.transcription()
.stream_request_with_options(options)
.keep_alive()
.encoding(Encoding::Linear16)
.sample_rate(44100)
.channels(2)
.endpointing(Endpointing::CustomDurationMs(300))
.interim_results(true)
.utterance_end_ms(1000)
.vad_events(true)
.no_delay(true)
.callback(callback_url)
.file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY)
.await?;

println!("Deepgram Request ID: {}", results.request_id());
while let Some(result) = results.next().await {
println!("got: {:?}", result);
}

Ok(())
}
1 change: 1 addition & 0 deletions examples/transcription/websocket/microphone_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async fn main() -> Result<(), DeepgramError> {
.stream(microphone_as_stream())
.await?;

println!("Deepgram Request ID: {}", results.request_id());
while let Some(result) = results.next().await {
println!("got: {:?}", result);
}
Expand Down
1 change: 1 addition & 0 deletions examples/transcription/websocket/simple_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async fn main() -> Result<(), DeepgramError> {
.file(PATH_TO_FILE, AUDIO_CHUNK_SIZE, FRAME_DELAY)
.await?;

println!("Deepgram Request ID: {}", results.request_id());
while let Some(result) = results.next().await {
println!("got: {:?}", result);
}
Expand Down
5 changes: 5 additions & 0 deletions src/common/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1887,7 +1887,12 @@ impl OptionsBuilder {
///
/// See the [Deepgram Callback Method feature docs][docs] for more info.
///
/// Note that modifying the callback method is only available for pre-recorded audio.
/// See the [Deepgram Callback feature docs for streaming][streaming-docs] for details
/// on streaming callbacks.
///
/// [docs]: https://developers.deepgram.com/docs/callback#pre-recorded-audio
/// [streaming-docs]: https://developers.deepgram.com/docs/callback#streaming-audio
///
/// # Examples
///
Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ pub enum DeepgramError {
/// An unexpected error occurred in the client
#[error("an unepected error occurred in the deepgram client: {0}")]
InternalClientError(anyhow::Error),

/// A Deepgram API server response was not in the expected format.
#[error("The Deepgram API server response was not in the expected format: {0}")]
UnexpectedServerResponse(anyhow::Error),
}

#[cfg_attr(not(feature = "listen"), allow(unused))]
Expand Down
54 changes: 52 additions & 2 deletions src/listen/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::{
time::Duration,
};

use anyhow::anyhow;
use bytes::Bytes;
use futures::{
channel::mpsc::{self, Receiver, Sender},
Expand All @@ -36,6 +37,7 @@ use tungstenite::{
protocol::frame::coding::{Data, OpCode},
};
use url::Url;
use uuid::Uuid;

use self::file_chunker::FileChunker;
use crate::{
Expand All @@ -62,6 +64,7 @@ pub struct WebsocketBuilder<'a> {
vad_events: Option<bool>,
stream_url: Url,
keep_alive: Option<bool>,
callback: Option<Url>,
}

impl Transcription<'_> {
Expand Down Expand Up @@ -143,6 +146,7 @@ impl Transcription<'_> {
vad_events: None,
stream_url: self.listen_stream_url(),
keep_alive: None,
callback: None,
}
}

Expand Down Expand Up @@ -214,6 +218,7 @@ impl<'a> WebsocketBuilder<'a> {
no_delay,
vad_events,
stream_url,
callback,
} = self;

let mut url = stream_url.clone();
Expand Down Expand Up @@ -257,6 +262,9 @@ impl<'a> WebsocketBuilder<'a> {
if let Some(vad_events) = vad_events {
pairs.append_pair("vad_events", &vad_events.to_string());
}
if let Some(callback) = callback {
pairs.append_pair("callback", callback.as_ref());
}
}

Ok(url)
Expand Down Expand Up @@ -315,6 +323,12 @@ impl<'a> WebsocketBuilder<'a> {

self
}

pub fn callback(mut self, callback: Url) -> Self {
self.callback = Some(callback);

self
}
}

impl<'a> WebsocketBuilder<'a> {
Expand Down Expand Up @@ -351,6 +365,7 @@ impl<'a> WebsocketBuilder<'a> {

let (tx, rx) = mpsc::channel(1);
let mut is_done = false;
let request_id = handle.request_id();
tokio::task::spawn(async move {
let mut handle = handle;
let mut tx = tx;
Expand Down Expand Up @@ -421,7 +436,11 @@ impl<'a> WebsocketBuilder<'a> {
}
}
});
Ok(TranscriptionStream { rx, done: false })
Ok(TranscriptionStream {
rx,
done: false,
request_id,
})
}

/// A low level interface to the Deepgram websocket transcription API.
Expand Down Expand Up @@ -628,6 +647,7 @@ impl Deref for Audio {
pub struct WebsocketHandle {
message_tx: Sender<WsMessage>,
response_rx: Receiver<Result<StreamResponse>>,
request_id: Uuid,
}

impl<'a> WebsocketHandle {
Expand All @@ -652,7 +672,21 @@ impl<'a> WebsocketHandle {
builder.body(())?
};

let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
let (ws_stream, upgrade_response) = tokio_tungstenite::connect_async(request).await?;

let request_id = upgrade_response
.headers()
.get("dg-request-id")
.ok_or(DeepgramError::UnexpectedServerResponse(anyhow!(
"Websocket upgrade headers missing request ID"
)))?
.to_str()
.ok()
.and_then(|req_header_str| Uuid::parse_str(req_header_str).ok())
.ok_or(DeepgramError::UnexpectedServerResponse(anyhow!(
"Received malformed request ID in websocket upgrade headers"
)))?;

let (message_tx, message_rx) = mpsc::channel(256);
let (response_tx, response_rx) = mpsc::channel(256);

Expand All @@ -670,6 +704,7 @@ impl<'a> WebsocketHandle {
Ok(WebsocketHandle {
message_tx,
response_rx,
request_id,
})
}

Expand Down Expand Up @@ -725,6 +760,10 @@ impl<'a> WebsocketHandle {
// eprintln!("<handle> receiving response: {resp:?}");
resp
}

pub fn request_id(&self) -> Uuid {
self.request_id
}
}

#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
Expand All @@ -741,6 +780,7 @@ pub struct TranscriptionStream {
#[pin]
rx: Receiver<Result<StreamResponse>>,
done: bool,
request_id: Uuid,
}

impl Stream for TranscriptionStream {
Expand All @@ -752,6 +792,16 @@ impl Stream for TranscriptionStream {
}
}

impl TranscriptionStream {
/// Returns the Deepgram request ID for the speech-to-text live request.
///
/// A request ID needs to be provided to Deepgram as part of any support
/// or troubleshooting assistance related to a specific request.
pub fn request_id(&self) -> Uuid {
self.request_id
}
}

mod file_chunker {
use bytes::{Bytes, BytesMut};
use futures::Stream;
Expand Down

0 comments on commit 5a1420c

Please sign in to comment.