Skip to content

Commit

Permalink
update whisper stream
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng committed Nov 4, 2024
1 parent 51e0f25 commit 2326a30
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 1,052 deletions.
2 changes: 1 addition & 1 deletion .github/pylint.conf
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ persistent=yes

# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.9
py-version=3.11

# Discover python modules and packages in the file system subtree.
recursive=no
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/ci_pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python: [3.8]
python: [3.9]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
Expand Down Expand Up @@ -170,10 +170,10 @@ jobs:
if: github.event_name == 'push' && github.repository_owner == 'mindspore-lab'
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: 3.9
- uses: "lvyufeng/action-kaggle-gpu-test@latest"
with:
kaggle_username: "${{ secrets.KAGGLE_USERNAME }}"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/doc_rst_check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip==24.0
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,7 @@ data*/

fusion_result.json
aclinit.json
xiyouji.txt
xiyouji.txt
*.safetensors
*.jit
flagged/
79 changes: 79 additions & 0 deletions llm/inference/whisper/app_realtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import gradio as gr
import time
import numpy as np
import mindspore
from mindnlp.transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Module-level setup: load the Whisper checkpoint once at import time and wrap
# it in an automatic-speech-recognition pipeline shared by every request.

# dtype handed to both the model loader and the pipeline (float16).
ms_dtype = mindspore.float16
# Checkpoint identifier used for the model, the processor and the page text.
MODEL_NAME = "openai/whisper-large-v3-turbo"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME, ms_dtype=ms_dtype, low_cpu_mem_usage=True
)

# The processor supplies the tokenizer and feature extractor for the pipeline.
processor = AutoProcessor.from_pretrained(MODEL_NAME)

pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    ms_dtype=ms_dtype,
)

# The prompt reads "The following are sentences in Mandarin."; it is converted
# to prompt ids and passed to the pipeline through generate_kwargs.
prompt = "以下是普通话的句子。" # must have periods
prompt_ids = processor.get_prompt_ids(prompt, return_tensors="ms")
generate_kwargs = {"prompt_ids": prompt_ids}

def transcribe(inputs, previous_transcription):
    """Transcribe one audio chunk and append it to the running transcript.

    Args:
        inputs: ``(sample_rate, audio_data)`` pair coming from the
            ``gr.Audio`` component; ``audio_data`` is a NumPy array.
        previous_transcription: Text accumulated by earlier calls.

    Returns:
        A ``(transcription, latency)`` tuple: the accumulated text and the
        elapsed wall-clock seconds formatted with two decimals. If anything
        fails, the text is returned unchanged and the second element is the
        string ``"Error"``.
    """
    start_time = time.time()
    try:
        sample_rate, audio_data = inputs
        audio_data = audio_data.astype(np.float32)

        # The ASR pipeline expects single-channel audio; average the channels
        # when the recorder delivers a 2-D array.
        # NOTE(review): assumes the layout is (samples, channels) — confirm.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Peak-normalise, but leave an all-zero (silent) chunk untouched:
        # dividing by a zero peak would fill the buffer with NaN.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data /= peak

        transcription = pipe(
            {"sampling_rate": sample_rate, "raw": audio_data},
            generate_kwargs=generate_kwargs,
        )["text"]
        previous_transcription += transcription

        latency = time.time() - start_time
        return previous_transcription, f"{latency:.2f}"
    except Exception as e:
        # Best-effort UI callback: report the problem and keep the transcript
        # gathered so far instead of crashing the stream.
        print(f"Error during Transcription: {e}")
        return previous_transcription, "Error"


def clear():
    """Return an empty string, used to blank the transcription textbox."""
    return str()

# "Microphone" tab: streams recorded audio into transcribe() and shows the
# accumulated text together with the latency of the last call.
with gr.Blocks() as microphone:
    with gr.Column():
        gr.Markdown(
            f"# Realtime Whisper Large V3 Turbo: \n"
            f" Transcribe Audio in Realtime. This Demo uses the Checkpoint "
            f"[{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n"
            f" Note: The first token takes about 5 seconds. After that, it works flawlessly."
        )
        with gr.Row():
            input_audio_microphone = gr.Audio(streaming=True)
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            clear_button = gr.Button("Clear Output")

        # The textbox is both an input and an output, so each call receives
        # the transcript produced so far and returns the extended one.
        input_audio_microphone.stream(
            fn=transcribe,
            inputs=[input_audio_microphone, output],
            outputs=[output, latency_textbox],
            time_limit=45,
            stream_every=2,
            concurrency_limit=None,
        )
        clear_button.click(fn=clear, outputs=[output])

# "Transcribe from file" tab: same layout as the microphone tab, but the audio
# is uploaded and transcribe() runs once per click on "Submit".
with gr.Blocks() as file:
    with gr.Column():
        gr.Markdown(
            f"# Realtime Whisper Large V3 Turbo: \n"
            f" Transcribe Audio in Realtime. This Demo uses the Checkpoint "
            f"[{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n"
            f" Note: The first token takes about 5 seconds. After that, it works flawlessly."
        )
        with gr.Row():
            input_audio_microphone = gr.Audio(sources="upload", type="numpy")
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Output")

        # The textbox is fed back in as an input, so repeated submissions
        # append to the text already shown.
        submit_button.click(
            fn=transcribe,
            inputs=[input_audio_microphone, output],
            outputs=[output, latency_textbox],
            concurrency_limit=None,
        )
        clear_button.click(fn=clear, outputs=[output])

# Combine the two tabs into a single themed app and start serving it.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.TabbedInterface(
        interface_list=[microphone, file],
        tab_names=["Microphone", "Transcribe from file"],
    )

demo.launch()
80 changes: 62 additions & 18 deletions llm/inference/whisper/app_stream.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,94 @@
import queue
import threading
import time

import gradio as gr
import numpy as np
import mindspore
from mindnlp.transformers import pipeline

from mindspore.dataset.audio import Resample
from mindnlp.transformers import pipeline, AutoProcessor
from silero_vad_mindspore import load

# Module-level configuration and one-time model setup for the streaming demo.

# Checkpoint identifier used for both the processor and the ASR pipeline.
MODEL_NAME = "openai/whisper-large-v3"
# NOTE(review): the next three constants are not referenced in the visible
# part of this file — confirm whether they are still needed.
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000
YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
# NOTE(review): transcribe() compares the speech score against a literal 0.5
# instead of this constant — confirm which one is intended.
THRESH_HOLD = 0.5

# Hands audio chunks from the Gradio callback to the pipeline_consumer thread.
stream_queue = queue.Queue()

# Voice-activity-detection model; called with an audio tensor and a sample rate.
vad_model = load('silero_vad_v4')

# The processor supplies the tokenizer and feature extractor for the pipeline.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    ms_dtype=mindspore.float16
)

# The prompt reads "The following are sentences in Mandarin."; it is converted
# to prompt ids and passed to the pipeline through generate_kwargs.
prompt = "以下是普通话的句子。" # must have periods
prompt_ids = processor.get_prompt_ids(prompt, return_tensors="ms")

# Transcript shared between pipeline_consumer (appends) and transcribe (returns).
text = ""
# NOTE(review): silence_count and resample are not referenced in the visible
# part of this file — confirm whether they are still needed.
silence_count = 0

resample = Resample(48000, 16000)
generate_kwargs = {"language": "zh", "task": "transcribe", "prompt_ids": prompt_ids}
# Additional generation options, currently disabled:
# "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), "no_speech_threshold": 0.5, "logprob_threshold": -1.0}

# warm up: run both models once on 16000 random samples (given to the VAD as
# 16 kHz audio) before any real request is handled.
random_sample = np.random.randn(16000).astype(np.float32)
vad_model(mindspore.tensor(random_sample), 16000)
pipe(random_sample, generate_kwargs=generate_kwargs, return_timestamps='word')

def pipeline_consumer():
    """Transcribe queued audio chunks forever and append them to ``text``.

    Meant to run in a background thread: it blocks on ``stream_queue``, sends
    each chunk through the ASR pipeline and appends the resulting text plus a
    newline to the module-level ``text``. A failing chunk is reported and
    skipped so that later chunks are still processed.
    """
    global text
    while True:
        chunk = stream_queue.get()
        try:
            generated_text = pipe(
                chunk, generate_kwargs=generate_kwargs, return_timestamps='word'
            )["text"]
            text += generated_text + '\n'
        except Exception as e:
            # Without this the first failure would end the thread and every
            # later chunk would sit in the queue untranscribed.
            print(f"Error during Transcription: {e}")
        finally:
            # Always balance get() so the queue's task accounting stays correct.
            stream_queue.task_done()

        # NOTE(review): get() already blocks while the queue is empty, so this
        # pause looks redundant — kept to preserve the original timing.
        if stream_queue.empty() and stream_queue.unfinished_tasks == 0:
            time.sleep(1)


def transcribe(stream, new_chunk):
generate_kwargs = {"language": "zh", "task": "transcribe"}
global text

sr, y = new_chunk

y = y.astype(np.float32)
y /= np.max(np.abs(y))
print(y)
# print('sample shape:', y.shape)
speech_score = vad_model(mindspore.tensor(y), sr)
speech_score = speech_score.item()
print('speech socre', speech_score)

if stream is not None:
stream = np.concatenate([stream, y])
else:
stream = y
if speech_score > 0.5:
if stream is not None:
if stream.shape < y.shape or (stream[-len(y):] - y).sum() != 0:
stream = np.concatenate([stream, y])
else:
stream = y

if stream.shape[0] < (3 * 48000):
return stream, None

text = pipe({"sampling_rate": sr, "raw": y}, generate_kwargs=generate_kwargs)["text"]

if str(text).endswith((".", "。", '?', "?", '!', "!", ":", ":")):
if stream is not None and stream.shape[0] >= (48000 * 5): # 5s if continue talk
print('stream shape:', stream.shape)
stream_queue.put({"sampling_rate": sr, "raw": stream})
stream = None

return stream, text # type: ignore

input_audio = gr.Audio(sources=["microphone"], streaming=True)
demo = gr.Interface(
transcribe,
["state", gr.Audio(sources=["microphone"], streaming=True)],
["state", input_audio],
["state", "text"],
live=True,
)

if __name__ == "__main__":
    # The consumer loops forever, so it must be a daemon thread; a non-daemon
    # thread would keep the process alive after the Gradio server has stopped.
    c = threading.Thread(target=pipeline_consumer, daemon=True)
    c.start()
    demo.launch()
2 changes: 1 addition & 1 deletion mindnlp/core/nn/modules/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def _ntuple(n):
def parse(x):
if isinstance(x, collections.Iterable):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
Expand Down
Loading

0 comments on commit 2326a30

Please sign in to comment.