
use faster whisper
Josh-XT committed Jul 5, 2024
1 parent c15f840 commit 9ca4e09
Showing 1 changed file with 51 additions and 84 deletions.
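
The gist of the change: the whisper.cpp bindings (and the manual ggml model download from Hugging Face) are dropped in favor of faster-whisper, which loads the model on CPU and skips silence with its built-in VAD filter, and captured audio is now passed to the transcriber as a WAV file path rather than a base64 string. A minimal sketch of the new transcription path, using the same faster-whisper calls as the updated examples/Listen.py (the model size and input file name below are placeholders, not values taken from the commit):

import logging
from faster_whisper import WhisperModel

logging.basicConfig(level=logging.INFO)

# "base.en" is a placeholder model size; download_root="models" and
# device="cpu" match the arguments used in the updated example.
model = WhisperModel("base.en", download_root="models", device="cpu")

# Transcribe a WAV file, dropping silent stretches with the built-in VAD.
segments, _info = model.transcribe(
    "recording.wav",  # placeholder path
    task="transcribe",
    vad_filter=True,
    vad_parameters=dict(min_silence_duration_ms=500),
)
text = "".join(segment.text for segment in segments)
logging.info(f"[STT] Transcribed: {text}")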
135 changes: 51 additions & 84 deletions examples/Listen.py
@@ -7,12 +7,8 @@
 import uuid
 from io import BytesIO
 from datetime import datetime
+import logging
 
-try:
-    import requests
-except ImportError:
-    subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"])
-    import requests
 try:
     import pyaudio
 except ImportError:
@@ -34,6 +30,11 @@
     subprocess.check_call([sys.executable, "-m", "pip", "install", "agixtsdk"])
     from agixtsdk import AGiXTSDK
 
+try:
+    from faster_whisper import WhisperModel
+except ImportError:
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "faster-whisper"])
+    from faster_whisper import WhisperModel
 
 audio = pyaudio.PyAudio()
 
@@ -58,90 +59,34 @@ def __init__(
             }
         )
         self.conversation_name = datetime.now().strftime("%Y-%m-%d")
-        self.w = None
-        if whisper_model != "":
-            try:
-                from whisper_cpp import Whisper
-            except ImportError:
-                subprocess.check_call(
-                    [sys.executable, "-m", "pip", "install", "whisper-cpp-pybind"]
-                )
-                try:
-                    from whisper_cpp import Whisper
-                except:
-                    whisper_model = ""
-        if whisper_model != "":
-            whisper_model = whisper_model.lower()
-            if whisper_model not in [
-                "tiny",
-                "tiny.en",
-                "base",
-                "base.en",
-                "small",
-                "small.en",
-                "medium",
-                "medium.en",
-                "large",
-                "large-v1",
-            ]:
-                whisper_model = "base.en"
-            os.makedirs(
-                os.path.join(os.getcwd(), "models", "whispercpp"), exist_ok=True
-            )
-            model_path = os.path.join(
-                os.getcwd(), "models", "whispercpp", f"ggml-{whisper_model}.bin"
-            )
-            if not os.path.exists(model_path):
-                r = requests.get(
-                    f"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-{whisper_model}.bin",
-                    allow_redirects=True,
-                )
-                open(model_path, "wb").write(r.content)
-            self.w = Whisper(model_path=model_path)
-
-    def process_audio_data(self, frames, rms_threshold=500):
-        audio_data = b"".join(frames)
-        audio_np = np.frombuffer(audio_data, dtype=np.int16)
-        rms = np.sqrt(np.mean(audio_np**2))
-        if rms > rms_threshold:
-            buffer = BytesIO()
-            with wave.open(buffer, "wb") as wf:
-                wf.setnchannels(1)
-                wf.setsampwidth(audio.get_sample_size(pyaudio.paInt16))
-                wf.setframerate(16000)
-                wf.writeframes(b"".join(frames))
-            wav_buffer = buffer.getvalue()
-            base64_audio = base64.b64encode(wav_buffer).decode()
-            thread = threading.Thread(
-                target=self.transcribe_audio,
-                args=(base64_audio),
-            )
-            thread.start()
+        self.TRANSCRIPTION_MODEL = whisper_model
 
-    def transcribe_audio(self, base64_audio):
-        if self.w:
-            filename = f"{uuid.uuid4().hex}.wav"
-            file_path = os.path.join(os.getcwd(), filename)
-            if not os.path.exists(file_path):
-                raise RuntimeError(f"Failed to load audio: {filename} does not exist.")
-            self.w.transcribe(file_path)
-            transcribed_text = self.w.output()
-            os.remove(os.path.join(os.getcwd(), filename))
-        else:
-            transcribed_text = self.sdk.execute_command(
-                agent_name=self.agent_name,
-                command_name="Transcribe WAV Audio",
-                command_args={"base64_audio": base64_audio},
-                conversation_name="AGiXT Terminal",
-            )
-        transcribed_text = transcribed_text.replace("[BLANK_AUDIO]", "")
+    def transcribe_audio(
+        self,
+        audio_path,
+        translate=False,
+    ):
+        self.w = WhisperModel(
+            self.TRANSCRIPTION_MODEL, download_root="models", device="cpu"
+        )
+        segments, _ = self.w.transcribe(
+            audio_path,
+            task="transcribe" if not translate else "translate",
+            vad_filter=True,
+            vad_parameters=dict(min_silence_duration_ms=500),
+        )
+        segments = list(segments)
+        user_input = ""
+        for segment in segments:
+            user_input += segment.text
+        logging.info(f"[STT] Transcribed User Input: {user_input}")
         for wake_word, wake_function in self.wake_functions.items():
-            if wake_word.lower() in transcribed_text.lower():
+            if wake_word.lower() in user_input.lower():
                 print("Wake word detected! Executing wake function...")
                 if wake_function:
-                    response = wake_function(transcribed_text)
+                    response = wake_function(user_input)
                 else:
-                    response = self.voice_chat(text=transcribed_text)
+                    response = self.voice_chat(text=user_input)
                 if response:
                     tts_response = self.sdk.execute_command(
                         agent_name=self.agent_name,
@@ -162,6 +107,28 @@ def transcribe_audio(self, base64_audio):
                     stream.write(generated_audio)
                     stream.stop_stream()
                     stream.close()
+        return user_input
+
+    def process_audio_data(self, frames, rms_threshold=500):
+        audio_data = b"".join(frames)
+        audio_np = np.frombuffer(audio_data, dtype=np.int16)
+        rms = np.sqrt(np.mean(audio_np**2))
+        if rms > rms_threshold:
+            buffer = BytesIO()
+            with wave.open(buffer, "wb") as wf:
+                wf.setnchannels(1)
+                wf.setsampwidth(audio.get_sample_size(pyaudio.paInt16))
+                wf.setframerate(16000)
+                wf.writeframes(b"".join(frames))
+            wav_buffer = buffer.getvalue()
+            file_path = os.path.join(os.getcwd(), f"{uuid.uuid4().hex}.wav")
+            with open(file_path, "wb") as f:
+                f.write(wav_buffer)
+            thread = threading.Thread(
+                target=self.transcribe_audio,
+                args=(file_path, False),
+            )
+            thread.start()
 
     def listen(self):
         print("Listening for wake word...")
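
The rewritten process_audio_data keeps the same energy gate as before: frames are only written to a WAV file and handed to transcribe_audio on a background thread when their RMS amplitude clears rms_threshold. A rough sketch of that check (the helper name is hypothetical, and the float cast is added here to avoid int16 overflow when squaring; the example squares the raw int16 buffer directly):

import numpy as np

def is_loud_enough(frames, rms_threshold=500):
    # frames: raw 16-bit PCM chunks as captured from PyAudio.
    audio_np = np.frombuffer(b"".join(frames), dtype=np.int16)
    # Root-mean-square amplitude; near-silence stays well below the threshold.
    rms = np.sqrt(np.mean(audio_np.astype(np.float64) ** 2))
    return rms > rms_threshold

When the gate passes, the example writes the buffered frames to a uniquely named WAV file in the working directory and starts transcribe_audio(file_path, False) on a new thread.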
