From 255040d6a45a9888de241f98ddaf557efc401830 Mon Sep 17 00:00:00 2001 From: noes14155 Date: Thu, 16 Nov 2023 11:38:24 +0400 Subject: [PATCH] TTS --- bot/chat_gpt.py | 17 ++++++++++---- bot/plugin_manager.py | 10 ++++---- bot/tts.py | 53 +++++++++++++++++++++++++++++++++++++++++++ bot/yt_transcript.py | 18 +++++++++++---- bot_service.py | 24 +++++++++++++++----- main.py | 2 +- 6 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 bot/tts.py diff --git a/bot/chat_gpt.py b/bot/chat_gpt.py index 4f7f5ec..5d46101 100644 --- a/bot/chat_gpt.py +++ b/bot/chat_gpt.py @@ -15,7 +15,7 @@ def __init__(self, api_key: str, api_base: str, default_model: str): openai.api_key = api_key openai.api_base = api_base - self.fetch_models_url = 'http://localhost:1337/models' + self.fetch_models_url = f'{api_base}/models' self.default_model = default_model self.models = [] self.headers = { @@ -30,20 +30,27 @@ def fetch_chat_models(self) -> List[str]: Returns: List[str]: The available chat models. """ + try: response = requests.get(self.fetch_models_url, headers=self.headers) except Exception: - return self.models.append(self.default_model) + self.models.append('gpt-4') + self.models.append('gpt-3.5-turbo') if response.status_code == 200: models_data = response.json() - for model in models_data.get('data'): - if "chat" in model['endpoints'][0]: - self.models.append(model['id']) + self.models.extend( + model['id'] + for model in models_data.get('data') + if "chat" in model['endpoints'][0] + ) else: print(f"Failed to fetch chat models. Status code: {response.status_code}") + if self.default_model not in self.models: self.models.append(self.default_model) return self.models + + def generate_response(self, instruction: str, plugin_result: str, history: List[Dict[str, str]], function: List[Dict[str, Any]] = None, model: str = 'gpt-3.5-turbo') -> Generator[str, None, None]: """ Generates a response using the selected model and input parameters. diff --git a/bot/plugin_manager.py b/bot/plugin_manager.py index 75f47c5..3ba0e3e 100644 --- a/bot/plugin_manager.py +++ b/bot/plugin_manager.py @@ -58,10 +58,10 @@ async def call_function(self, function_name, arguments): :param arguments: The arguments to pass to the function :return: The result of the function call """ - plugin = self.__get_plugin_by_function_name(function_name) - if not plugin: + if plugin := self.__get_plugin_by_function_name(function_name): + return json.dumps(await plugin.execute(function_name, **json.loads(arguments)), default=str) + else: return json.dumps({'error': f'Function {function_name} not found'}) - return json.dumps(await plugin.execute(function_name, **json.loads(arguments)), default=str) def get_plugin_source_name(self, function_name) -> str: """ @@ -71,9 +71,7 @@ def get_plugin_source_name(self, function_name) -> str: :return: The source name of the plugin """ plugin = self.__get_plugin_by_function_name(function_name) - if not plugin: - return '' - return plugin.get_source_name() + return '' if not plugin else plugin.get_source_name() def __get_plugin_by_function_name(self, function_name): """ diff --git a/bot/tts.py b/bot/tts.py new file mode 100644 index 0000000..4e01a3d --- /dev/null +++ b/bot/tts.py @@ -0,0 +1,53 @@ +import os +import tempfile +from gtts import gTTS +import requests +from pydub import AudioSegment + +class TextToSpeech: + def __init__(self, api_key: str, api_base: str, ): + self.headers = { + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json' + } + self.api_base = api_base + self.voice_id = "XB0fDUnXU5powFXDhCwa" + response = requests.get(f"{self.api_base}/audio/tts/voices") + self.use_openai_tts = response.status_code == 200 + + async def text_to_speech(self, text, filename): + if self.use_openai_tts: + data = { + "text": text, + "voice_id": self.voice_id + } + response = requests.post(f"{self.api_base}/audio/tts", + json=data, + headers=self.headers) + with open(filename, "wb") as f: + f.write(response.content) + else: + tts = gTTS(text=text, lang='en') + with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as temp_audio: + temp_audio_path = temp_audio.name + tts.save(temp_audio_path) + os.rename(temp_audio_path, filename) + return filename + + async def create_audio_segments(self, text, chunk_size=300): + audio_filenames = [] + for i in range(0, len(text), chunk_size): + chunk = text[i:i + chunk_size] + audio_filename = f"audio_{i//chunk_size}.mp3" + await self.text_to_speech(chunk, audio_filename) + audio_filenames.append(audio_filename) + return audio_filenames + + def join_audio_segments(self, audio_filenames, output_filename="output.mp3"): + output = AudioSegment.empty() + for filename in audio_filenames: + segment = AudioSegment.from_file(filename) + output += segment + os.remove(filename) # delete the chunk file + output.export(output_filename, format="mp3") + return output_filename \ No newline at end of file diff --git a/bot/yt_transcript.py b/bot/yt_transcript.py index eaf7afa..42c96b7 100644 --- a/bot/yt_transcript.py +++ b/bot/yt_transcript.py @@ -3,6 +3,13 @@ from youtube_transcript_api import YouTubeTranscriptApi class YoutubeTranscript: + def get_transcript(self, video_id, lang_code): + transcript_list = YouTubeTranscriptApi.list_transcripts(video_id) + transcript = transcript_list.find_transcript([lang_code]) + if transcript is None: + transcript = transcript_list.find_manually_created_transcript([lang_code]) + return transcript + async def get_yt_transcript(self, message_content: str, lang: str) -> str: """ Retrieves and formats the transcript of a YouTube video based on a given video URL. @@ -28,20 +35,21 @@ def extract_video_id(message_content: str) -> str: r"(https?://)?(www\.)?(youtube|youtu|youtube-nocookie)\.(com|be)/(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})" ) match = youtube_link_pattern.search(message_content) - return match.group(6) if match else None + return match[6] if match else None try: video_id = extract_video_id(message_content) if not video_id: return None - transcript_list = YouTubeTranscriptApi.list_transcripts(video_id) - first_transcript = next(iter(transcript_list), None) - if not first_transcript: + transcript = self.get_transcript(video_id, lang) + + #first_transcript = next(iter(transcript_list), None) + if not transcript: return None formatted_transcript = ". ".join( - [entry["text"] for entry in first_transcript.fetch()] + [entry["text"] for entry in transcript.fetch()] )[:2500] response = f"Please provide a summary or additional information for the following YouTube video transcript in a few concise bullet points.\n\n{formatted_transcript}" diff --git a/bot_service.py b/bot_service.py index 8c661e9..9a224cf 100644 --- a/bot_service.py +++ b/bot_service.py @@ -2,7 +2,7 @@ import re import requests from colorama import Fore -from aiogram.types import ReplyKeyboardRemove +from aiogram.types import ReplyKeyboardRemove, FSInputFile from aiogram.utils.keyboard import ReplyKeyboardBuilder, InlineKeyboardBuilder from dotenv import load_dotenv from gradio_client import Client @@ -17,7 +17,8 @@ voice_transcript, web_search, yt_transcript, - chat_gpt + chat_gpt, + tts ) @@ -65,6 +66,7 @@ def __init__(self): self.ft = file_transcript.FileTranscript() self.ig = image_generator.ImageGenerator(HG_IMG2TEXT=self.HG_IMG2TEXT, HG_TEXT2IMAGE=self.HG_TEXT2IMAGE) self.gpt = chat_gpt.ChatGPT(self.GPT_KEY,self.API_BASE,self.DEFAULT_MODEL) + self.tts = tts.TextToSpeech(self.GPT_KEY,self.API_BASE) self.ocr = ocr.OCR(config=" --psm 3 --oem 3") self.db.create_tables() self.plugin = plugin_manager.PluginManager(self.plugin_config) @@ -142,6 +144,7 @@ async def select_persona(self,user_id,user_message): return response, markup async def changemodel(self): + self.gpt.fetch_chat_models() response = "Select from the models or providers" markup = self.generate_keyboard('model') return response, markup @@ -227,9 +230,18 @@ async def chat(self, call, waiting_id, bot, process_prompt = ''): if full_text not in ['', sent_text]: await bot.edit_message_text(chat_id=call.chat.id, message_id=waiting_id, text=full_text, reply_markup=markup) self.cancel_flag = False + + await self.generate_tts(full_text, call, bot) return - + async def generate_tts(self, full_text, call, bot): + audio_chunks = await self.tts.create_audio_segments(full_text) + audio_file_path = self.tts.join_audio_segments(audio_chunks) + audio_file = FSInputFile(audio_file_path) + await bot.send_voice(chat_id=call.chat.id, voice=audio_file) + os.remove(audio_file_path) + + async def voice(self, call, waiting_id, bot): user_id = call.from_user.id lang, persona, model = self.db.get_settings(user_id) @@ -340,9 +352,9 @@ async def __common_generate(self, call, process_prompt = ''): web_text = await self.ws.extract_text_from_website(prompt) if web_text is not None: prompt = web_text - yt_transcript = await self.yt.get_yt_transcript(user_message, lang) - if yt_transcript is not None: - prompt = yt_transcript + #yt_transcript = await self.yt.get_yt_transcript(user_message, lang) + #if yt_transcript is not None: + # prompt = yt_transcript EXTRA_PROMPT = bot_messages["EXTRA_PROMPT"] if user.first_name is not None: bot_messages["bot_prompt"] += f"You should address the user as '{user.first_name}'" diff --git a/main.py b/main.py index 2438322..4590b17 100644 --- a/main.py +++ b/main.py @@ -201,7 +201,7 @@ async def regenerate(callback: types.CallbackQuery): await service.chat(call=service.last_call[callback.from_user.id], waiting_id=waiting_id, bot=bot) @dp.callback_query(F.data == "cancel") -async def regenerate(callback: types.CallbackQuery): +async def cancel(callback: types.CallbackQuery): service.cancel_flag = True @dp.message(F.content_type.in_({'text'}))