-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f648a2e
commit 4624f89
Showing
4 changed files
with
156 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
MLX Transcribe: A tool for transcribing audio files using MLX Whisper | ||
Requirements: | ||
1. ffmpeg - Install using: | ||
- macOS: `brew install ffmpeg` | ||
- Ubuntu: `sudo apt-get install ffmpeg` | ||
- Windows: Download from https://ffmpeg.org/download.html | ||
2. mlx-whisper library: | ||
pip install mlx-whisper | ||
Example Usage: | ||
- Place your audio files in the 'storage/audio' directory | ||
Eg: download https://www.ted.com/talks/reid_hoffman_and_kevin_scott_the_evolution_of_ai_and_how_it_will_impact_human_creativity | ||
- Run this script to transcribe audio files | ||
- Supports various audio formats (mp3, mp4, wav, etc.) | ||
""" | ||
|
||
from pathlib import Path | ||
from phi.agent import Agent | ||
from phi.model.openai import OpenAIChat | ||
from phi.tools.mlx_transcribe import MLXTranscribe | ||
|
||
# Get audio files from storage/audio directory | ||
phidata_root_dir = Path(__file__).parent.parent.parent.resolve() | ||
audio_storage_dir = phidata_root_dir.joinpath("storage/audio") | ||
if not audio_storage_dir.exists(): | ||
audio_storage_dir.mkdir(exist_ok=True, parents=True) | ||
|
||
agent = Agent( | ||
name="Transcription Agent", | ||
model=OpenAIChat(id="gpt-4o"), | ||
tools=[MLXTranscribe(base_dir=audio_storage_dir)], | ||
instructions=[ | ||
"To transcribe an audio file, use the `transcribe` tool with the name of the audio file as the argument.", | ||
"You can find all available audio files using the `read_files` tool.", | ||
], | ||
markdown=True, | ||
) | ||
|
||
agent.print_response("Summarize the reid hoffman ted talk, split into sections", stream=True) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,137 @@ | ||
""" | ||
MLX Transcribe Tools will need ffmpeg installed to work. | ||
Install ffmpeg using `brew install ffmpeg` on macOS. | ||
MLX Transcribe Tools - Audio Transcription using Apple's MLX Framework | ||
Requirements: | ||
- ffmpeg: Required for audio processing | ||
macOS: brew install ffmpeg | ||
Ubuntu: apt-get install ffmpeg | ||
Windows: Download from https://ffmpeg.org/download.html | ||
- mlx-whisper: Install via pip | ||
pip install mlx-whisper | ||
This module provides tools for transcribing audio files using the MLX Whisper model, | ||
optimized for Apple Silicon processors. It supports various audio formats and | ||
provides high-quality transcription capabilities. | ||
""" | ||
|
||
import json | ||
from pathlib import Path | ||
from typing import Optional, Union, Tuple, List, Dict, Any | ||
|
||
from phi.tools import Toolkit | ||
from phi.utils.log import logger | ||
from typing import Optional | ||
|
||
try: | ||
import mlx_whisper # type: ignore | ||
import mlx_whisper | ||
except ImportError: | ||
raise ImportError("`mlx_whisper` not installed. Please install using `pip install mlx-whisper`") | ||
|
||
|
||
class MLXTranscribeTools(Toolkit): | ||
def __init__(self, file_path: str, path_or_hf_repo: Optional[str] = "mlx-community/whisper-large-v3-turbo"): | ||
super().__init__(name="transcribe") | ||
class MLXTranscribe(Toolkit): | ||
def __init__( | ||
self, | ||
base_dir: Optional[Path] = None, | ||
read_files_in_base_dir: bool = True, | ||
path_or_hf_repo: str = "mlx-community/whisper-large-v3-turbo", | ||
verbose: Optional[bool] = None, | ||
temperature: Optional[Union[float, Tuple[float, ...]]] = None, | ||
compression_ratio_threshold: Optional[float] = None, | ||
logprob_threshold: Optional[float] = None, | ||
no_speech_threshold: Optional[float] = None, | ||
condition_on_previous_text: Optional[bool] = None, | ||
initial_prompt: Optional[str] = None, | ||
word_timestamps: Optional[bool] = None, | ||
prepend_punctuations: Optional[str] = None, | ||
append_punctuations: Optional[str] = None, | ||
clip_timestamps: Optional[Union[str, List[float]]] = None, | ||
hallucination_silence_threshold: Optional[float] = None, | ||
decode_options: Optional[dict] = None, | ||
): | ||
super().__init__(name="mlx_transcribe") | ||
|
||
self.file_path = file_path | ||
self.path_or_hf_repo = path_or_hf_repo | ||
self.base_dir: Path = base_dir or Path.cwd() | ||
self.path_or_hf_repo: str = path_or_hf_repo | ||
self.verbose: Optional[bool] = verbose | ||
self.temperature: Optional[Union[float, Tuple[float, ...]]] = temperature | ||
self.compression_ratio_threshold: Optional[float] = compression_ratio_threshold | ||
self.logprob_threshold: Optional[float] = logprob_threshold | ||
self.no_speech_threshold: Optional[float] = no_speech_threshold | ||
self.condition_on_previous_text: Optional[bool] = condition_on_previous_text | ||
self.initial_prompt: Optional[str] = initial_prompt | ||
self.word_timestamps: Optional[bool] = word_timestamps | ||
self.prepend_punctuations: Optional[str] = prepend_punctuations | ||
self.append_punctuations: Optional[str] = append_punctuations | ||
self.clip_timestamps: Optional[Union[str, List[float]]] = clip_timestamps | ||
self.hallucination_silence_threshold: Optional[float] = hallucination_silence_threshold | ||
self.decode_options: Optional[dict] = decode_options | ||
|
||
self.register(self.transcribe) | ||
if read_files_in_base_dir: | ||
self.register(self.read_files) | ||
|
||
def transcribe(self, file_path: str) -> str: | ||
def transcribe(self, file_name: str) -> str: | ||
""" | ||
Transcribe uses Apple's MLX Whisper model. | ||
Args: | ||
file_path (str): The path to the audio file to transcribe. | ||
path_or_hf_repo (str): The path to the local model or the Hugging Face repository to use for the model. Defaults to "mlx-community/whisper-large-v3-turbo". | ||
file_name (str): The name of the audio file to transcribe. | ||
Returns: | ||
str: The transcribed text or an error message if the transcription fails. | ||
""" | ||
try: | ||
audio_file_path = str(self.base_dir.joinpath(file_name)) | ||
if audio_file_path is None: | ||
return "No audio file path provided" | ||
|
||
logger.info(f"Transcribing audio file {audio_file_path}") | ||
transcription_kwargs: Dict[str, Any] = { | ||
"path_or_hf_repo": self.path_or_hf_repo, | ||
} | ||
if self.verbose is not None: | ||
transcription_kwargs["verbose"] = self.verbose | ||
if self.temperature is not None: | ||
transcription_kwargs["temperature"] = self.temperature | ||
if self.compression_ratio_threshold is not None: | ||
transcription_kwargs["compression_ratio_threshold"] = self.compression_ratio_threshold | ||
if self.logprob_threshold is not None: | ||
transcription_kwargs["logprob_threshold"] = self.logprob_threshold | ||
if self.no_speech_threshold is not None: | ||
transcription_kwargs["no_speech_threshold"] = self.no_speech_threshold | ||
if self.condition_on_previous_text is not None: | ||
transcription_kwargs["condition_on_previous_text"] = self.condition_on_previous_text | ||
if self.initial_prompt is not None: | ||
transcription_kwargs["initial_prompt"] = self.initial_prompt | ||
if self.word_timestamps is not None: | ||
transcription_kwargs["word_timestamps"] = self.word_timestamps | ||
if self.prepend_punctuations is not None: | ||
transcription_kwargs["prepend_punctuations"] = self.prepend_punctuations | ||
if self.append_punctuations is not None: | ||
transcription_kwargs["append_punctuations"] = self.append_punctuations | ||
if self.clip_timestamps is not None: | ||
transcription_kwargs["clip_timestamps"] = self.clip_timestamps | ||
if self.hallucination_silence_threshold is not None: | ||
transcription_kwargs["hallucination_silence_threshold"] = self.hallucination_silence_threshold | ||
if self.decode_options is not None: | ||
transcription_kwargs.update(self.decode_options) | ||
|
||
transcription = mlx_whisper.transcribe(audio_file_path, **transcription_kwargs) | ||
return transcription.get("text", "") | ||
except Exception as e: | ||
_e = f"Failed to transcribe audio file {e}" | ||
logger.error(_e) | ||
return _e | ||
|
||
def read_files(self) -> str: | ||
"""Returns a list of files in the base directory | ||
Returns: | ||
str: The transcribed text. | ||
str: A JSON string containing the list of files in the base directory. | ||
""" | ||
try: | ||
logger.info(f"Transcribing audio file {file_path}") | ||
text = mlx_whisper.transcribe(file_path, path_or_hf_repo=self.path_or_hf_repo)["text"] | ||
return text | ||
logger.info(f"Reading files in : {self.base_dir}") | ||
return json.dumps([str(file_name) for file_name in self.base_dir.iterdir()], indent=4) | ||
except Exception as e: | ||
logger.error(f"Failed to transcribe audio file {e}") | ||
return f"Error: {e}" | ||
logger.error(f"Error reading files: {e}") | ||
return f"Error reading files: {e}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters