-
Notifications
You must be signed in to change notification settings - Fork 893
/
Copy pathstt.rs
153 lines (129 loc) · 7.24 KB
/
stt.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use futures::future::join_all;
use screenpipe_audio::pyannote::embedding::EmbeddingExtractor;
use screenpipe_audio::pyannote::identify::EmbeddingManager;
use screenpipe_audio::stt::{prepare_segments, stt};
use screenpipe_audio::vad_engine::{SileroVad, VadEngine};
use screenpipe_audio::whisper::WhisperModel;
use screenpipe_audio::{AudioInput, AudioTranscriptionEngine};
use screenpipe_core::Language;
use std::path::PathBuf;
use std::sync::Arc;
use strsim::levenshtein;
use tokio::sync::Mutex;
use tracing::debug;
/// Transcription accuracy check for the speech-to-text pipeline.
///
/// Each test case pairs a WAV file (relative to `CARGO_MANIFEST_DIR`) with its
/// expected lowercase transcript. For every file the audio is decoded and split
/// by `prepare_segments` (Silero VAD + pyannote segmentation/embedding models),
/// and each segment is transcribed with `stt` using Whisper large-v3-turbo.
///
/// Scoring is `1 - levenshtein(expected, actual.to_lowercase()) / chars(expected)`.
/// This can drop below zero when the transcript is much longer than expected.
///
/// # Panics
///
/// The run panics in any of these cases:
/// - a model, audio file or default input device cannot be loaded;
/// - a transcription task panics;
/// - the average accuracy across all files is not above 55%.
#[tokio::main]
async fn main() {
    // `debug!` output is only visible when a subscriber is installed, e.g.
    // `tracing_subscriber::fmt().with_max_level(tracing::Level::DEBUG).init();`
    debug!("starting transcription accuracy test");

    let test_cases = vec![
        (
            "test_data/accuracy1.wav",
            r#"yo louis, here's the tldr of that mind-blowing meeting. bob's cat walked across his keyboard 3 times. productivity increased by 200%. sarah's virtual background glitched, revealing she was actually on a beach. no one noticed. you successfully pretended to be engaged while scrolling twitter. achievement unlocked! 7 people said "you're on mute" in perfect synchronization. new world record. meeting could've been an email. shocking. key takeaway: we're all living in a simulation, and the devs are laughing. peace out, llama3.2:3b-instruct-q4_k_m"#,
        ),
        (
            "test_data/accuracy2.wav",
            r#"bro - got some good stuff from screenpipe here's the lowdown on your day, you productivity ninja: absolutely demolished that 2-hour coding sesh on the new feature. the keyboard is still smoking, bro! crushed 3 client calls like a boss. they're probably writing love letters to you as we speak, make sure to close john tomorrow 8.00 am according to our notes, let the cash flow in! spent 45 mins on slack. 90% memes, 10% actual work. perfectly balanced, as all things should bewatched a rust tutorial. way to flex those brain muscles, you nerd! overall, you're killing it! 80% of your time on high-value tasks. the other 20%? probably spent admiring your own reflection, you handsome devil. ps: seriously, quit tiktok. your fbi agent is getting bored watching you scroll endlessly. what's the plan for tomorrow? more coding? more memes? world domination? generated by your screenpipe ai assistant (who's definitely not planning to take over the world... yet)"#,
        ),
        (
            "test_data/accuracy3.wav",
            r#"again, screenpipe allows you to get meeting summaries, locally, without leaking data to openai, with any apps, like whatsapp, meet, zoom, etc. and it's open source at github.com/mediar-ai/screenpipe"#,
        ),
        (
            "test_data/accuracy4.wav",
            r#"eventually but, i mean, i feel like but, i mean, first, i mean, you think your your vision smart will be interesting because, yeah, you install once. you pay us, you install once. that that yours. so, basically, all the time microsoft explained, you know, ms office, long time ago, you just buy the the the software that you can using there forever unless you wanna you wanna update upgrade is the better version. right? so it's a little bit, you know"#,
        ),
        (
            "test_data/accuracy5.wav",
            r#"thank you. yeah. so i cannot they they took it, refresh because of my one set top top time. and, also, second thing is, your byte was stolen. by the time?"#,
        ),
        // Add more test cases as needed
    ];
    // Without this, an empty list yields 0/0 = NaN and a misleading threshold failure.
    assert!(!test_cases.is_empty(), "no test cases configured");

    // One engine value drives both model construction and every `stt` call,
    // so the two can never disagree.
    let engine = Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo);
    let whisper_model = Arc::new(Mutex::new(
        WhisperModel::new(&engine).expect("failed to load whisper model"),
    ));
    let vad_engine: Arc<Mutex<Box<dyn VadEngine + Send>>> = Arc::new(Mutex::new(Box::new(
        SileroVad::new().await.expect("failed to initialise silero vad"),
    )));

    let project_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let pyannote_dir = project_dir.join("models").join("pyannote");
    let segmentation_model_path = pyannote_dir.join("segmentation-3.0.onnx");
    let embedding_model_path = pyannote_dir.join("wespeaker_en_voxceleb_CAM++.onnx");
    let embedding_extractor = Arc::new(std::sync::Mutex::new(
        EmbeddingExtractor::new(
            embedding_model_path
                .to_str()
                .expect("embedding model path is not valid UTF-8"),
        )
        .expect("failed to load embedding model"),
    ));
    let embedding_manager = EmbeddingManager::new(usize::MAX);

    // Resolved once and shared by every task instead of being looked up per file.
    let device = Arc::new(
        screenpipe_audio::default_input_device().expect("no default audio input device"),
    );

    let mut tasks = Vec::with_capacity(test_cases.len());
    for (audio_file, expected_transcription) in test_cases {
        let whisper_model = Arc::clone(&whisper_model);
        let vad_engine = Arc::clone(&vad_engine);
        let segmentation_model_path = segmentation_model_path.clone();
        let embedding_extractor = Arc::clone(&embedding_extractor);
        let embedding_manager = embedding_manager.clone();
        let engine = Arc::clone(&engine);
        let device = Arc::clone(&device);
        let audio_path = project_dir.join(audio_file);

        tasks.push(tokio::spawn(async move {
            let audio_data =
                screenpipe_audio::pcm_decode(&audio_path).expect("Failed to decode audio file");
            let audio_input = AudioInput {
                data: Arc::new(audio_data.0),
                // Hardcoded to match the test fixtures' sample rate.
                // NOTE(review): if `pcm_decode` also reports the decoded rate,
                // prefer that value; confirm against its signature.
                sample_rate: 44100,
                channels: 1,
                device,
            };
            // Used as the device label for both segmentation and transcription.
            let device_name = audio_input.device.to_string();

            let mut segments = prepare_segments(
                &audio_input.data,
                vad_engine,
                &segmentation_model_path,
                embedding_manager,
                embedding_extractor,
                &device_name,
            )
            .await
            .expect("failed to prepare segments");

            // The shared Whisper model is locked for this file's whole segment
            // stream, so files are transcribed one at a time.
            let mut whisper_model_guard = whisper_model.lock().await;
            let mut transcription = String::new();
            while let Some(segment) = segments.recv().await {
                let transcript = stt(
                    &segment.samples,
                    audio_input.sample_rate,
                    &device_name,
                    &mut whisper_model_guard,
                    Arc::clone(&engine),
                    None,
                    vec![Language::English],
                )
                .await
                .expect("speech-to-text failed");

                // Join segments with exactly one space. A bare concatenation
                // glues boundary words together when segments come back trimmed,
                // inflating the edit distance.
                let text = transcript.trim();
                if text.is_empty() {
                    continue;
                }
                if !transcription.is_empty() {
                    transcription.push(' ');
                }
                transcription.push_str(text);
            }
            drop(whisper_model_guard);

            let distance = levenshtein(expected_transcription, &transcription.to_lowercase());
            // `levenshtein` counts edits in chars, so normalise by the char
            // count as well; byte length over-counts non-ASCII text.
            let expected_len = expected_transcription.chars().count();
            let accuracy = 1.0 - distance as f64 / expected_len as f64;
            (audio_file, expected_transcription, transcription, accuracy)
        }));
    }

    let results = join_all(tasks).await;
    let total_tests = results.len();
    let mut total_accuracy = 0.0;
    for result in results {
        let (audio_file, expected_transcription, transcription, accuracy) =
            result.expect("transcription task panicked");
        println!("file: {}", audio_file);
        println!("expected: {}", expected_transcription);
        println!("actual: {}", transcription);
        println!("accuracy: {:.2}%", accuracy * 100.0);
        total_accuracy += accuracy;
    }

    let average_accuracy = total_accuracy / total_tests as f64;
    println!("average accuracy: {:.2}%", average_accuracy * 100.0);
    assert!(average_accuracy > 0.55, "average accuracy is below 55%");
}