From 4624f89f66c613049459ebeb55dab459ab021faa Mon Sep 17 00:00:00 2001 From: Ashpreet Bedi Date: Tue, 29 Oct 2024 10:16:19 +0000 Subject: [PATCH] Improve MLX tool --- cookbook/tools/mlx_transcribe.py | 42 ++++++++++ cookbook/tools/trascribe.py | 24 ------ phi/tools/mlx_transcribe.py | 131 ++++++++++++++++++++++++++----- pyproject.toml | 1 + 4 files changed, 156 insertions(+), 42 deletions(-) create mode 100644 cookbook/tools/mlx_transcribe.py delete mode 100644 cookbook/tools/trascribe.py diff --git a/cookbook/tools/mlx_transcribe.py b/cookbook/tools/mlx_transcribe.py new file mode 100644 index 0000000000..2ba6d204e1 --- /dev/null +++ b/cookbook/tools/mlx_transcribe.py @@ -0,0 +1,42 @@ +""" +MLX Transcribe: A tool for transcribing audio files using MLX Whisper + +Requirements: +1. ffmpeg - Install using: + - macOS: `brew install ffmpeg` + - Ubuntu: `sudo apt-get install ffmpeg` + - Windows: Download from https://ffmpeg.org/download.html + +2. mlx-whisper library: + pip install mlx-whisper + +Example Usage: +- Place your audio files in the 'storage/audio' directory + Eg: download https://www.ted.com/talks/reid_hoffman_and_kevin_scott_the_evolution_of_ai_and_how_it_will_impact_human_creativity +- Run this script to transcribe audio files +- Supports various audio formats (mp3, mp4, wav, etc.) +""" + +from pathlib import Path +from phi.agent import Agent +from phi.model.openai import OpenAIChat +from phi.tools.mlx_transcribe import MLXTranscribe + +# Get audio files from storage/audio directory +phidata_root_dir = Path(__file__).parent.parent.parent.resolve() +audio_storage_dir = phidata_root_dir.joinpath("storage/audio") +if not audio_storage_dir.exists(): + audio_storage_dir.mkdir(exist_ok=True, parents=True) + +agent = Agent( + name="Transcription Agent", + model=OpenAIChat(id="gpt-4o"), + tools=[MLXTranscribe(base_dir=audio_storage_dir)], + instructions=[ + "To transcribe an audio file, use the `transcribe` tool with the name of the audio file as the argument.", + "You can find all available audio files using the `read_files` tool.", + ], + markdown=True, +) + +agent.print_response("Summarize the reid hoffman ted talk, split into sections", stream=True) diff --git a/cookbook/tools/trascribe.py b/cookbook/tools/trascribe.py deleted file mode 100644 index a33f6ee881..0000000000 --- a/cookbook/tools/trascribe.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -MLX Transcribe Tools will need ffmpeg installed to work. -Install ffmpeg using `brew install ffmpeg` on macOS. - -MLX Transcribe Tools will need the `mlx-whisper` library installed to work. -Install `mlx-whisper` using `pip install mlx-whisper` -""" - -from phi.tools.mlx_transcribe import MLXTranscribeTools -from phi.agent import Agent -from phi.model.openai import OpenAIChat -import os - -file_path = os.path.expanduser("~/path/to/file.mp3") - - -agent = Agent( - model=OpenAIChat(id="gpt-4o"), - tools=[MLXTranscribeTools(file_path=file_path)], - name="Transcribe Agent", - markdown=True, -) - -agent.print_response(f"Transcribe file {file_path}") diff --git a/phi/tools/mlx_transcribe.py b/phi/tools/mlx_transcribe.py index dd536f1a0c..17e164ff59 100644 --- a/phi/tools/mlx_transcribe.py +++ b/phi/tools/mlx_transcribe.py @@ -1,42 +1,137 @@ """ -MLX Transcribe Tools will need ffmpeg installed to work. -Install ffmpeg using `brew install ffmpeg` on macOS. 
+MLX Transcribe Tools - Audio Transcription using Apple's MLX Framework + +Requirements: + - ffmpeg: Required for audio processing + macOS: brew install ffmpeg + Ubuntu: apt-get install ffmpeg + Windows: Download from https://ffmpeg.org/download.html + + - mlx-whisper: Install via pip + pip install mlx-whisper + +This module provides tools for transcribing audio files using the MLX Whisper model, +optimized for Apple Silicon processors. It supports various audio formats and +provides high-quality transcription capabilities. """ +import json +from pathlib import Path +from typing import Optional, Union, Tuple, List, Dict, Any + from phi.tools import Toolkit from phi.utils.log import logger -from typing import Optional try: - import mlx_whisper # type: ignore + import mlx_whisper except ImportError: raise ImportError("`mlx_whisper` not installed. Please install using `pip install mlx-whisper`") -class MLXTranscribeTools(Toolkit): - def __init__(self, file_path: str, path_or_hf_repo: Optional[str] = "mlx-community/whisper-large-v3-turbo"): - super().__init__(name="transcribe") +class MLXTranscribe(Toolkit): + def __init__( + self, + base_dir: Optional[Path] = None, + read_files_in_base_dir: bool = True, + path_or_hf_repo: str = "mlx-community/whisper-large-v3-turbo", + verbose: Optional[bool] = None, + temperature: Optional[Union[float, Tuple[float, ...]]] = None, + compression_ratio_threshold: Optional[float] = None, + logprob_threshold: Optional[float] = None, + no_speech_threshold: Optional[float] = None, + condition_on_previous_text: Optional[bool] = None, + initial_prompt: Optional[str] = None, + word_timestamps: Optional[bool] = None, + prepend_punctuations: Optional[str] = None, + append_punctuations: Optional[str] = None, + clip_timestamps: Optional[Union[str, List[float]]] = None, + hallucination_silence_threshold: Optional[float] = None, + decode_options: Optional[dict] = None, + ): + super().__init__(name="mlx_transcribe") - self.file_path = file_path - self.path_or_hf_repo = path_or_hf_repo + self.base_dir: Path = base_dir or Path.cwd() + self.path_or_hf_repo: str = path_or_hf_repo + self.verbose: Optional[bool] = verbose + self.temperature: Optional[Union[float, Tuple[float, ...]]] = temperature + self.compression_ratio_threshold: Optional[float] = compression_ratio_threshold + self.logprob_threshold: Optional[float] = logprob_threshold + self.no_speech_threshold: Optional[float] = no_speech_threshold + self.condition_on_previous_text: Optional[bool] = condition_on_previous_text + self.initial_prompt: Optional[str] = initial_prompt + self.word_timestamps: Optional[bool] = word_timestamps + self.prepend_punctuations: Optional[str] = prepend_punctuations + self.append_punctuations: Optional[str] = append_punctuations + self.clip_timestamps: Optional[Union[str, List[float]]] = clip_timestamps + self.hallucination_silence_threshold: Optional[float] = hallucination_silence_threshold + self.decode_options: Optional[dict] = decode_options self.register(self.transcribe) + if read_files_in_base_dir: + self.register(self.read_files) - def transcribe(self, file_path: str) -> str: + def transcribe(self, file_name: str) -> str: """ Transcribe uses Apple's MLX Whisper model. Args: - file_path (str): The path to the audio file to transcribe. - path_or_hf_repo (str): The path to the local model or the Hugging Face repository to use for the model. Defaults to "mlx-community/whisper-large-v3-turbo". + file_name (str): The name of the audio file to transcribe. 
+
+        Returns:
+            str: The transcribed text or an error message if the transcription fails.
+        """
+        try:
+            if not file_name:
+                return "No audio file name provided"
+            audio_file_path = str(self.base_dir.joinpath(file_name))
+
+            logger.info(f"Transcribing audio file {audio_file_path}")
+            transcription_kwargs: Dict[str, Any] = {
+                "path_or_hf_repo": self.path_or_hf_repo,
+            }
+            if self.verbose is not None:
+                transcription_kwargs["verbose"] = self.verbose
+            if self.temperature is not None:
+                transcription_kwargs["temperature"] = self.temperature
+            if self.compression_ratio_threshold is not None:
+                transcription_kwargs["compression_ratio_threshold"] = self.compression_ratio_threshold
+            if self.logprob_threshold is not None:
+                transcription_kwargs["logprob_threshold"] = self.logprob_threshold
+            if self.no_speech_threshold is not None:
+                transcription_kwargs["no_speech_threshold"] = self.no_speech_threshold
+            if self.condition_on_previous_text is not None:
+                transcription_kwargs["condition_on_previous_text"] = self.condition_on_previous_text
+            if self.initial_prompt is not None:
+                transcription_kwargs["initial_prompt"] = self.initial_prompt
+            if self.word_timestamps is not None:
+                transcription_kwargs["word_timestamps"] = self.word_timestamps
+            if self.prepend_punctuations is not None:
+                transcription_kwargs["prepend_punctuations"] = self.prepend_punctuations
+            if self.append_punctuations is not None:
+                transcription_kwargs["append_punctuations"] = self.append_punctuations
+            if self.clip_timestamps is not None:
+                transcription_kwargs["clip_timestamps"] = self.clip_timestamps
+            if self.hallucination_silence_threshold is not None:
+                transcription_kwargs["hallucination_silence_threshold"] = self.hallucination_silence_threshold
+            if self.decode_options is not None:
+                transcription_kwargs.update(self.decode_options)
+
+            transcription = mlx_whisper.transcribe(audio_file_path, **transcription_kwargs)
+            return transcription.get("text", "")
+        except Exception as e:
+            _e = f"Failed to transcribe audio file: {e}"
+            logger.error(_e)
+            return _e
+
+    def read_files(self) -> str:
+        """Returns a list of files in the base directory
 
         Returns:
-            str: The transcribed text.
+            str: A JSON string containing the list of files in the base directory.
         """
         try:
-            logger.info(f"Transcribing audio file {file_path}")
-            text = mlx_whisper.transcribe(file_path, path_or_hf_repo=self.path_or_hf_repo)["text"]
-            return text
+            logger.info(f"Reading files in: {self.base_dir}")
+            return json.dumps([str(file_name) for file_name in self.base_dir.iterdir()], indent=4)
         except Exception as e:
-            logger.error(f"Failed to transcribe audio file {e}")
-            return f"Error: {e}"
+            logger.error(f"Error reading files: {e}")
+            return f"Error reading files: {e}"
diff --git a/pyproject.toml b/pyproject.toml
index 9b07a908f1..f426dd3821 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -125,6 +125,7 @@ module = [
     "langchain.*",
     "llama_index.*",
     "mistralai.*",
+    "mlx_whisper.*",
     "nest_asyncio.*",
     "newspaper.*",
     "numpy.*",
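Note on the kwargs plumbing: everything `MLXTranscribe.transcribe()` assembles above is forwarded directly to `mlx_whisper.transcribe`. A minimal standalone sketch of that call, using the same default model repo as the toolkit (the audio path below is a placeholder for illustration, not a file added by this patch):

    import mlx_whisper

    # Placeholder path; any ffmpeg-decodable audio format should work.
    audio_path = "storage/audio/sample.mp3"

    # Same default model as MLXTranscribe; word_timestamps mirrors one of the
    # optional pass-through kwargs the toolkit forwards when set.
    result = mlx_whisper.transcribe(
        audio_path,
        path_or_hf_repo="mlx-community/whisper-large-v3-turbo",
        word_timestamps=True,
    )
    print(result["text"])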