From 79fad134571163520d1f91675559161d8827fb3e Mon Sep 17 00:00:00 2001 From: noes14155 Date: Sat, 23 Sep 2023 00:50:36 +0400 Subject: [PATCH] Enhance --- Dockerfile | 9 +++- bot/chat_gpt.py | 27 +++++++--- bot/database.py | 91 +++++++++++++++++++++++++++++++++- bot/file_transcript.py | 107 +++++++++++++++++++++------------------- bot/image_generator.py | 66 +++++++++++++++++++------ bot/language_manager.py | 76 ++++++++++++++++++++++------ bot/ocr.py | 21 ++++++++ bot/plugin_manager.py | 33 ++++++++++--- bot/voice_transcript.py | 50 +++++++++++++++---- bot/web_search.py | 20 ++++++-- bot/yt_transcript.py | 28 +++++++++-- docker-compose.yml | 2 + interference/app.py | 45 +++++++++-------- replit_detector.py | 22 +++++++-- updater.py | 76 ++++++++++++++++++---------- 15 files changed, 506 insertions(+), 167 deletions(-) diff --git a/Dockerfile b/Dockerfile index 69c6cd6..4278c82 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,19 +2,24 @@ FROM python:3.10-slim WORKDIR /app ENV PYTHONUNBUFFERED=1 + RUN apt-get update \ && apt-get install -y --no-install-recommends git flac ffmpeg tesseract-ocr wget \ && apt-get -y clean \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* + RUN mkdir -p /usr/share/tesseract-ocr/4.00/tessdata/script/ #RUN wget https://github.com/tesseract-ocr/tessdata_fast/raw/main/script/Devanagari.traineddata -P /usr/share/tesseract-ocr/4.00/tessdata/script/ -ENV TESSDATA_PREFIX=/usr/share/tesseract-ocr/4.00/tessdata + +#ENV TESSDATA_PREFIX=/usr/share/tesseract-ocr/4.00/tessdata + COPY requirements.txt . RUN pip install --upgrade pip \ && pip install -r requirements.txt \ && rm requirements.txt \ && pip cache purge \ && rm -rf ~/.cache/pip/* -COPY . . +COPY . . +#VOLUME /app/personas CMD ["python3", "./main.py"] \ No newline at end of file diff --git a/bot/chat_gpt.py b/bot/chat_gpt.py index 0ab8ed9..e7e184a 100644 --- a/bot/chat_gpt.py +++ b/bot/chat_gpt.py @@ -4,7 +4,14 @@ from typing import List, Dict, Any, Generator class ChatGPT: - def __init__(self,api_key,api_base): + def __init__(self, api_key: str, api_base: str): + """ + Initializes the ChatGPT instance with the provided API key and base URL. + + Args: + api_key (str): The OpenAI API key. + api_base (str): The base URL for the OpenAI API. + """ openai.api_key = api_key openai.api_base = api_base @@ -15,13 +22,17 @@ def __init__(self,api_key,api_base): 'Content-Type': 'application/json' } - + def fetch_chat_models(self) -> List[str]: + """ + Fetches available chat models from the OpenAI API and stores them in the models field. - def fetch_chat_models(self): + Returns: + List[str]: The available chat models. + """ response = requests.get(self.fetch_models_url, headers=self.headers) if response.status_code == 200: - ModelsData = response.json() - for model in ModelsData.get('data'): + models_data = response.json() + for model in models_data.get('data'): if "chat" in model['endpoints'][0]: self.models.append(model['id']) else: @@ -35,25 +46,25 @@ def generate_response( ) -> Generator[str, None, None]: """ Generates a response using the selected model and input parameters. + Args: instruction (str): The instruction for generating the response. plugin_result (str): The plugin result. history (List[Dict[str, str]]): The chat history. - prompt (str): The user prompt. function (List[Dict[str, Any]]): The functions to be used. model (str): The selected model. + Yields: str: Each message in the response stream. """ retries = 0 - while True: + while True: text = '' if not model.startswith('gpt'): plugin_result = '' function = [] print('Unsupported model. Plugins not used') messages = [ - {"role": "system", "content": instruction}, {"role": "system", "content": plugin_result}, *history diff --git a/bot/database.py b/bot/database.py index b803441..b157103 100644 --- a/bot/database.py +++ b/bot/database.py @@ -3,13 +3,41 @@ class Database: + """ + A class that provides methods for interacting with a SQLite database. + + Attributes: + conn: The SQLite database connection object. + + Methods: + __init__(self, db_file): Initializes a new instance of the Database class and connects to the specified SQLite database file. + create_tables(self): Creates the settings and history tables in the database if they do not already exist. + close_connection(self): Closes the database connection. + insert_settings(self, user_id, lang='en', persona='Julie_friend', model='gpt-3.5-turbo'): Inserts settings data for a user into the settings table. + update_settings(self, user_id, lang='en', persona='Julie_friend', model='gpt-3.5-turbo'): Updates the settings data for a user in the settings table. + insert_history(self, user_id, role, content): Inserts history data for a user into the history table. + get_settings(self, user_id): Retrieves the settings data for a user from the settings table. + get_history(self, user_id): Retrieves the history data for a user from the history table. + delete_user_history(self, user_id): Deletes all history data for a user from the history table. + delete_last_2_user_history(self, user_id): Deletes the last 2 history entries for a user from the history table. + """ + def __init__(self, db_file): + """ + Initializes a new instance of the Database class and connects to the specified SQLite database file. + + Args: + db_file (str): The path to the SQLite database file. + """ if not os.path.exists(db_file): open(db_file, "a").close() self.conn = sqlite3.connect(db_file) self.create_tables() def create_tables(self): + """ + Creates the settings and history tables in the database if they do not already exist. + """ settings_query = """CREATE TABLE IF NOT EXISTS settings (user_id INTEGER PRIMARY KEY, lang TEXT DEFAULT 'en', persona TEXT DEFAULT 'Julie_friend', @@ -22,27 +50,65 @@ def create_tables(self): self.conn.commit() def close_connection(self): + """ + Closes the database connection. + """ if self.conn: self.conn.close() - def insert_settings(self, user_id, lang='en', persona='Julie_friend',model='gpt-3.5-turbo'): + def insert_settings(self, user_id, lang='en', persona='Julie_friend', model='gpt-3.5-turbo'): + """ + Inserts settings data for a user into the settings table. + + Args: + user_id (int): The ID of the user. + lang (str, optional): The language setting. Defaults to 'en'. + persona (str, optional): The persona setting. Defaults to 'Julie_friend'. + model (str, optional): The model setting. Defaults to 'gpt-3.5-turbo'. + """ query = """INSERT OR IGNORE INTO settings (user_id, lang, persona, model) VALUES (?, ?, ?, ?)""" self.conn.execute(query, (user_id, lang, persona,model)) self.conn.commit() - def update_settings(self, user_id, lang='en', persona='Julie_friend',model='gpt-3.5-turbo'): + def update_settings(self, user_id, lang='en', persona='Julie_friend', model='gpt-3.5-turbo'): + """ + Updates the settings data for a user in the settings table. + + Args: + user_id (int): The ID of the user. + lang (str, optional): The language setting. Defaults to 'en'. + persona (str, optional): The persona setting. Defaults to 'Julie_friend'. + model (str, optional): The model setting. Defaults to 'gpt-3.5-turbo'. + """ query = """UPDATE settings SET lang=?, persona=?, model=? WHERE user_id=?""" self.conn.execute(query, (lang, persona, model, user_id)) self.conn.commit() def insert_history(self, user_id, role, content): + """ + Inserts history data for a user into the history table. + + Args: + user_id (int): The ID of the user. + role (str): The role of the user. + content (str): The content of the history entry. + """ query = """INSERT INTO history (user_id, role, content) VALUES (?, ?, ?)""" self.conn.execute(query, (user_id, role, content)) self.conn.commit() def get_settings(self, user_id): + """ + Retrieves the settings data for a user from the settings table. + + Args: + user_id (int): The ID of the user. + + Returns: + tuple: A tuple containing the language, persona, and model settings for the user. + """ query = """SELECT lang, persona, model FROM settings WHERE user_id=?""" row = self.conn.execute(query, (user_id,)).fetchone() if row: @@ -52,16 +118,37 @@ def get_settings(self, user_id): return None, None, None def get_history(self, user_id): + """ + Retrieves the history data for a user from the history table. + + Args: + user_id (int): The ID of the user. + + Returns: + list: A list of tuples containing the role and content of each history entry. + """ query = """SELECT role, content FROM history WHERE user_id=?""" rows = self.conn.execute(query, (user_id,)).fetchall() return rows def delete_user_history(self, user_id): + """ + Deletes all history data for a user from the history table. + + Args: + user_id (int): The ID of the user. + """ query = """DELETE FROM history WHERE user_id=?""" self.conn.execute(query, (user_id,)) self.conn.commit() def delete_last_2_user_history(self, user_id): + """ + Deletes the last 2 history entries for a user from the history table. + + Args: + user_id (int): The ID of the user. + """ query = """ DELETE FROM history WHERE rowid IN ( diff --git a/bot/file_transcript.py b/bot/file_transcript.py index c61d728..12d6afc 100644 --- a/bot/file_transcript.py +++ b/bot/file_transcript.py @@ -1,3 +1,5 @@ +from typing import Optional +from aiogram import types import csv import email import os @@ -9,53 +11,35 @@ import pypdf from bs4 import BeautifulSoup class FileTranscript: - def __init__(self): - self.VALID_EXTENSIONS = [ - "txt", - "rtf", - "md", - "html", - "xml", - "csv", - "json", - "js", - "css", - "py", - "java", - "c", - "cpp", - "php", - "rb", - "swift", - "sql", - "sh", - "bat", - "ps1", - "ini", - "cfg", - "conf", - "log", - "svg", - "epub", - "mobi", - "tex", - "docx", - "odt", - "xlsx", - "ods", - "pptx", - "odp", - "eml", - "htaccess", - "nginx.conf", - "pdf", - ] - async def read_document(self, filename): + """ + The FileTranscript class is responsible for reading and downloading various types of files, + such as text documents, spreadsheets, presentations, emails, and more. + """ + + VALID_EXTENSIONS = [ + "txt", "rtf", "md", "html", "xml", "csv", "json", "js", "css", "py", "java", "c", + "cpp", "php", "rb", "swift", "sql", "sh", "bat", "ps1", "ini", "cfg", "conf", "log", + "svg", "epub", "mobi", "tex", "docx", "odt", "xlsx", "ods", "pptx", "odp", "eml", + "htaccess", "nginx.conf", "pdf" + ] + + async def read_document(self, filename: str) -> str: + """ + Reads the contents of a document file based on its extension. + + Args: + filename (str): The name of the file to read. + + Returns: + str: The contents of the file. + """ try: extension = filename.split(".")[-1] - contents = '' + contents = "" + if extension not in self.VALID_EXTENSIONS: - contents = "Invalid document file" + return "Invalid document file" + if extension == "pdf": with open(filename, "rb") as f: pdf_reader = pypdf.PdfReader(f) @@ -64,56 +48,75 @@ async def read_document(self, filename): page_obj = pdf_reader.pages[page_num] page_text = page_obj.extract_text() contents += page_text + elif extension == "docx": doc = docx.Document(filename) contents = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + elif extension in ["xlsx", "ods"]: workbook = openpyxl.load_workbook(filename, read_only=True) sheet = workbook.active for row in sheet.iter_rows(values_only=True): contents += "\t".join([str(cell_value) for cell_value in row]) + "\n" + elif extension in ["pptx", "odp"]: presentation = pptx.Presentation(filename) for slide in presentation.slides: for shape in slide.shapes: if hasattr(shape, "text"): contents += shape.text + "\n" + elif extension == "eml": with open(filename, "r") as f: msg = email.message_from_file(f) for part in msg.walk(): if part.get_content_type() == "text/plain": contents += part.get_payload() + elif extension in ["html", "xml"]: with open(filename, "r") as f: soup = BeautifulSoup(f, "html.parser") contents = soup.get_text() + elif extension == "csv": with open(filename, "r") as f: reader = csv.reader(f) for row in reader: contents += "\t".join(row) + "\n" + else: with open(filename, "r") as f: contents = f.read() + except Exception as e: contents = f"Error during file download: {str(e)}" + return contents - async def download_file(self, bot, message: types.Message): + + async def download_file(self, bot, message: types.Message) -> Optional[str]: + """ + Downloads a file from a Telegram message. + + Args: + bot: The bot instance. + message (types.Message): The Telegram message containing the file. + + Returns: + Optional[str]: The full file path of the downloaded file, or None if there was an error. + """ try: - if message.document is not None: - file = message.document - file_extension = ( - file.file_name.split(".")[-1] if file.file_name is not None else "" - ) - else: + if message.document is None: return None + + file = message.document + file_extension = file.file_name.split(".")[-1] if file.file_name else "" file_path = f"{file.file_id}.{file_extension}" file_dir = "downloaded_files" os.makedirs(file_dir, exist_ok=True) full_file_path = os.path.join(file_dir, file_path) - await bot.download(file=file , destination=full_file_path) + await bot.download(file=file, destination=full_file_path) return full_file_path + except Exception as e: print(f"Error during file download: {str(e)}") return None \ No newline at end of file diff --git a/bot/image_generator.py b/bot/image_generator.py index edd3ea0..cf56204 100644 --- a/bot/image_generator.py +++ b/bot/image_generator.py @@ -1,23 +1,42 @@ import asyncio import aiohttp -from aiogram.types import KeyboardButton, ReplyKeyboardMarkup import gradio as gr from gradio_client import Client import threading import openai -import io class ImageGenerator: - def __init__(self, HG_IMG2TEXT): - self.HG_IMG2TEXT = HG_IMG2TEXT - def load_gradio(): - gr.load("models/stabilityai/stable-diffusion-2-1").launch(server_port=7860) + """ + The `ImageGenerator` class is responsible for generating image captions and images using various AI models. + It uses the `gradio` library to launch a server for image caption generation and the `openai` library to generate images. + """ + + def __init__(self, HG_IMG2TEXT: str): + """ + Initializes the `ImageGenerator` class and launches the `gradio` server in a separate thread. - gradio_thread = threading.Thread(target=load_gradio) + Args: + HG_IMG2TEXT (str): The API endpoint for image-to-text conversion. + """ + self.HG_IMG2TEXT = HG_IMG2TEXT + gradio_thread = threading.Thread(target=self.load_gradio) gradio_thread.start() - - - async def generate_imagecaption(self, url, HG_TOKEN): + + def load_gradio(self): + gr.load("models/stabilityai/stable-diffusion-2-1").launch(server_port=7860) + + async def generate_imagecaption(self, url: str, HG_TOKEN: str) -> str: + """ + Generates a caption for the given image URL by sending a request to the `HG_IMG2TEXT` API endpoint. + Retries the request if there is a server error or the response is still loading. + + Args: + url (str): The URL of the image. + HG_TOKEN (str): The token for authorization. + + Returns: + str: The generated caption for the image. + """ headers = {"Authorization": f"Bearer {HG_TOKEN}"} retries = 0 async with aiohttp.ClientSession() as session: @@ -44,12 +63,31 @@ async def generate_imagecaption(self, url, HG_TOKEN): else: return f"Error: {await resp2.text()}" - async def generate_image(prompt): + async def generate_image(self, prompt: str) -> str: + """ + Generates an image using the `openai` library by providing a prompt. + + Args: + prompt (str): The prompt for generating the image. + + Returns: + str: The generated image as text. + """ client = Client("http://127.0.0.1:7860/") - text = client.predict(prompt, api_name="/predict" ) + text = client.predict(prompt, api_name="/predict") return text - - async def dalle_generate(self, prompt, size): + + async def dalle_generate(self, prompt: str, size: int) -> str: + """ + Generates an image using the `openai` library by providing a prompt and size. + + Args: + prompt (str): The prompt for generating the image. + size (int): The size of the image. + + Returns: + str: The URL of the generated image. + """ try: response = openai.Image.create( prompt=prompt, diff --git a/bot/language_manager.py b/bot/language_manager.py index 8a9ccad..e410142 100644 --- a/bot/language_manager.py +++ b/bot/language_manager.py @@ -2,38 +2,76 @@ import os import yaml class LanguageManager: - def __init__(self, default_lang, database): + """ + The LanguageManager class is responsible for managing languages and loading language files in a Python application. + """ + + def __init__(self, default_lang: str, database): + """ + Initializes the LanguageManager instance with the default language and a database connection. + + Args: + default_lang (str): The default language. + database: The database connection object. + """ self.DEFAULT_LANGUAGE = default_lang self.db_connection = database self.available_lang = {} self.plugin_lang = {} self.load_available_languages() self.load_default_language() + def load_available_languages(self): - if os.path.exists("./language_files/languages.yml"): - with open("./language_files/languages.yml", "r", encoding="utf8") as f: + """ + Loads the available languages from the "languages.yml" file. + """ + language_file_path = "./language_files/languages.yml" + if os.path.exists(language_file_path): + with open(language_file_path, "r", encoding="utf8") as f: self.available_lang = yaml.safe_load(f) else: print("languages.yml does not exist") exit() + def load_default_language(self): - if os.path.exists(f"./language_files/{self.DEFAULT_LANGUAGE}.yml"): - with open( - f"language_files/{self.DEFAULT_LANGUAGE}.yml", "r", encoding="utf8" - ) as file: + """ + Loads the default language from a YAML file based on the default language specified during initialization. + """ + language_file_path = f"./language_files/{self.DEFAULT_LANGUAGE}.yml" + if os.path.exists(language_file_path): + with open(language_file_path, "r", encoding="utf8") as file: self.plugin_lang = yaml.safe_load(file) else: print(f"{self.DEFAULT_LANGUAGE}.yml does not exist") exit() - def set_language(self, user_id, lang): + + def set_language(self, user_id: str, lang: str): + """ + Sets the language for a user and updates the database with the new language. + + Args: + user_id (str): The user ID. + lang (str): The language to set for the user. + """ if not user_id: print("user_id does not exist") - language,persona,model = self.db_connection.get_settings(user_id) + language, persona, model = self.db_connection.get_settings(user_id) if language: - self.db_connection.update_settings(user_id, lang,persona,model) + self.db_connection.update_settings(user_id, lang, persona, model) else: - self.db_connection.insert_settings(user_id, lang,persona,model) - def local_messages(self, user_id): + self.db_connection.insert_settings(user_id, lang, persona, model) + + def local_messages(self, user_id: str): + """ + Retrieves localized messages for a user based on their language and persona. If the language file does not exist, + falls back to the default language. + + Args: + user_id (str): The user ID. + + Returns: + dict: Localized messages for the user. + """ lang, persona, model = self.db_connection.get_settings(user_id) if not lang: lang = self.DEFAULT_LANGUAGE @@ -51,10 +89,18 @@ def local_messages(self, user_id): bot_messages["bot_prompt"] = personas[persona] bot_messages["bot_prompt"] += f"\n\nWhen replying to the user you should act as the above given persona\nIt's currently {datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')}" return bot_messages - def load_personas(self, personas): - for file_name in os.listdir("personas"): + + def load_personas(self, personas: dict): + """ + Loads personas from text files in the "personas" directory and stores them in a dictionary. + + Args: + personas (dict): The dictionary to store the personas. + """ + personas_directory = "personas" + for file_name in os.listdir(personas_directory): if file_name.endswith('.txt'): - file_path = os.path.join("personas", file_name) + file_path = os.path.join(personas_directory, file_name) with open(file_path, 'r', encoding='utf-8') as file: file_content = file.read() persona = file_name.split('.')[0] diff --git a/bot/ocr.py b/bot/ocr.py index 87b8e79..6fca5c6 100644 --- a/bot/ocr.py +++ b/bot/ocr.py @@ -4,10 +4,31 @@ from PIL import Image, ImageEnhance class OCR: + """ + The OCR class is used for performing optical character recognition (OCR) on an image. + It uses the Tesseract OCR engine to extract text from the image. + """ + def __init__(self, config=None): + """ + Initializes an instance of the OCR class with an optional configuration parameter. + + Args: + config (str): Optional configuration parameter for Tesseract OCR. + """ self.config = config def process_image(self, url): + """ + Processes an image from a given URL and returns the extracted text if it contains non-whitespace characters, + otherwise returns None. + + Args: + url (str): The URL of the image to be processed. + + Returns: + str or None: The extracted text from the image, or None if no text is found. + """ try: # Load the image response = requests.get(url) diff --git a/bot/plugin_manager.py b/bot/plugin_manager.py index af81c54..75f47c5 100644 --- a/bot/plugin_manager.py +++ b/bot/plugin_manager.py @@ -19,7 +19,12 @@ class PluginManager: A class to manage the plugins and call the correct functions """ - def __init__(self,plugins): + def __init__(self, plugins): + """ + Initializes the PluginManager with a list of enabled plugins + + :param plugins: A dictionary containing the list of enabled plugins + """ enabled_plugins = plugins.get('plugins', []) plugin_mapping = { 'wolfram': WolframAlphaPlugin, @@ -39,13 +44,19 @@ def __init__(self,plugins): def get_functions_specs(self): """ - Return the list of function specs that can be called by the model + Returns the list of function specs that can be called by the model + + :return: A list of function specs """ - return [spec for specs in map(lambda plugin: plugin.get_spec(), self.plugins) for spec in specs] + return [spec for plugin in self.plugins for spec in plugin.get_spec()] async def call_function(self, function_name, arguments): """ - Call a function based on the name and parameters provided + Calls a function based on the name and parameters provided + + :param function_name: The name of the function to call + :param arguments: The arguments to pass to the function + :return: The result of the function call """ plugin = self.__get_plugin_by_function_name(function_name) if not plugin: @@ -54,7 +65,10 @@ async def call_function(self, function_name, arguments): def get_plugin_source_name(self, function_name) -> str: """ - Return the source name of the plugin + Returns the source name of the plugin + + :param function_name: The name of the function + :return: The source name of the plugin """ plugin = self.__get_plugin_by_function_name(function_name) if not plugin: @@ -62,5 +76,10 @@ def get_plugin_source_name(self, function_name) -> str: return plugin.get_source_name() def __get_plugin_by_function_name(self, function_name): - return next((plugin for plugin in self.plugins - if function_name in map(lambda spec: spec.get('name'), plugin.get_spec())), None) + """ + Returns the plugin that contains the specified function name + + :param function_name: The name of the function + :return: The plugin that contains the function, or None if not found + """ + return next((plugin for plugin in self.plugins if function_name in map(lambda spec: spec.get('name'), plugin.get_spec())), None) diff --git a/bot/voice_transcript.py b/bot/voice_transcript.py index 5aff54a..57d35e8 100644 --- a/bot/voice_transcript.py +++ b/bot/voice_transcript.py @@ -1,43 +1,75 @@ +from typing import Optional +from aiogram import types import os -import asyncio import aiogram.types as types import speech_recognition as sr from pydub import AudioSegment class VoiceTranscript: + """ + The `VoiceTranscript` class is responsible for transcribing audio files using the Google Speech Recognition API. + It provides methods for downloading audio files and transcribing them into text. + """ + def __init__(self): self.rec = sr.Recognizer() - async def transcribe_audio(self, audio_file_path, lang): + + async def transcribe_audio(self, audio_file_path: str, lang: str) -> str: + """ + Transcribes an audio file into text. + + Args: + audio_file_path (str): The path of the audio file. + lang (str): The language code. + + Returns: + str: The transcription of the audio file. + """ try: wav_file_path = audio_file_path.replace(".ogg", ".wav") audio = AudioSegment.from_ogg(audio_file_path) audio.export(wav_file_path, format="wav") + with sr.AudioFile(wav_file_path) as audio_file: audio = self.rec.record(audio_file) - transcription = self.rec.recognize_google(audio, language=f"{lang}") + + transcription = self.rec.recognize_google(audio, language=lang) os.remove(wav_file_path) - + except Exception as e: transcription = f"Error during audio transcription: {str(e)}" - return transcription - async def download_file(self, bot, message: types.Message): + + return transcription + + async def download_file(self, bot, message: types.Message) -> Optional[str]: + """ + Downloads an audio file from a message using a Telegram bot. + + Args: + bot: The Telegram bot instance. + message (types.Message): The message containing the audio file. + + Returns: + Optional[str]: The path of the downloaded file, or None if an error occurs during the download. + """ try: if message.audio is not None: file = message.audio - file_extension = ( - file.file_name.split(".")[-1] if file.file_name is not None else "ogg" - ) + file_extension = file.file_name.split(".")[-1] if file.file_name is not None else "ogg" elif message.voice is not None: file = message.voice file_extension = "ogg" else: return None + file_path = f"{file.file_id}.{file_extension}" file_dir = "downloaded_files" os.makedirs(file_dir, exist_ok=True) full_file_path = os.path.join(file_dir, file_path) + await bot.download(file=file, destination=full_file_path) return full_file_path + except Exception as e: print(f"Error during file download: {str(e)}") return None diff --git a/bot/web_search.py b/bot/web_search.py index 9d68161..d091f00 100644 --- a/bot/web_search.py +++ b/bot/web_search.py @@ -4,15 +4,27 @@ class WebSearch: - - - async def extract_text_from_website(self, url): + """ + The `WebSearch` class is responsible for extracting text from a given website URL using asynchronous HTTP requests and web scraping techniques. + """ + + async def extract_text_from_website(self, url: str) -> str: + """ + Extracts the text content from the website asynchronously. + + Args: + url (str): The URL of the website to extract text from. + + Returns: + str: The extracted text content from the website, or None if the URL is invalid or an exception occurs. + """ if not isinstance(url, str): raise ValueError("url must be a string") parsed_url = urlparse(url) - if parsed_url.scheme == "" or parsed_url.netloc == "": + if not parsed_url.scheme or not parsed_url.netloc: return None + try: async with aiohttp.ClientSession() as session: async with session.get(url) as response: diff --git a/bot/yt_transcript.py b/bot/yt_transcript.py index f617492..eaf7afa 100644 --- a/bot/yt_transcript.py +++ b/bot/yt_transcript.py @@ -3,25 +3,47 @@ from youtube_transcript_api import YouTubeTranscriptApi class YoutubeTranscript: - async def get_yt_transcript(self, message_content, lang): - def extract_video_id(message_content): + async def get_yt_transcript(self, message_content: str, lang: str) -> str: + """ + Retrieves and formats the transcript of a YouTube video based on a given video URL. + + Args: + message_content (str): The message content which includes a YouTube video URL. + lang (str): The language parameter. + + Returns: + str: The formatted transcript or None if an error occurs. + """ + def extract_video_id(message_content: str) -> str: + """ + Extracts the video ID from a YouTube video URL. + + Args: + message_content (str): The message content which includes a YouTube video URL. + + Returns: + str: The video ID or None if no match is found. + """ youtube_link_pattern = re.compile( r"(https?://)?(www\.)?(youtube|youtu|youtube-nocookie)\.(com|be)/(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})" ) match = youtube_link_pattern.search(message_content) return match.group(6) if match else None + try: video_id = extract_video_id(message_content) if not video_id: return None + transcript_list = YouTubeTranscriptApi.list_transcripts(video_id) first_transcript = next(iter(transcript_list), None) if not first_transcript: return None - # translated_transcript = first_transcript.translate(f"{lang}") + formatted_transcript = ". ".join( [entry["text"] for entry in first_transcript.fetch()] )[:2500] + response = f"Please provide a summary or additional information for the following YouTube video transcript in a few concise bullet points.\n\n{formatted_transcript}" return response except Exception as e: diff --git a/docker-compose.yml b/docker-compose.yml index a68af22..c60a6d3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,6 +7,8 @@ services: context: . dockerfile: Dockerfile restart: always +# volumes: +# - ./home/personas:/app/personas # g4f_server: # container_name: g4f_server # ports: diff --git a/interference/app.py b/interference/app.py index 3e42997..84634b1 100644 --- a/interference/app.py +++ b/interference/app.py @@ -41,27 +41,30 @@ @app.route('/models') def get_models(): - models = [ - { - "id": "gpt-4", - "endpoints": [ - "/chat/completions" - ] - }, - { - "id": "gpt-3.5-turbo", - "endpoints": [ - "/chat/completions" - ] - } - ] - - response = { - "data": models, - "object": "list" - } - - return jsonify(response) + """ + Returns a JSON response containing a list of models and their endpoints. + """ + models = [ + { + "id": "gpt-4", + "endpoints": [ + "/chat/completions" + ] + }, + { + "id": "gpt-3.5-turbo", + "endpoints": [ + "/chat/completions" + ] + } + ] + + response = { + "data": models, + "object": "list" + } + + return jsonify(response) @app.route("/chat/completions", methods=["POST"]) def chat_completions(): diff --git a/replit_detector.py b/replit_detector.py index 7a1f5b4..28b5e1c 100644 --- a/replit_detector.py +++ b/replit_detector.py @@ -1,21 +1,33 @@ from flask import Flask import threading import os -import sys class ReplitFlaskApp: + """ + A Flask application that can be run on Repl.it or locally. + """ + def __init__(self): + """ + Initializes the ReplitFlaskApp class by creating a Flask application and defining a route for the root URL. + """ self.app = Flask(__name__) @self.app.route('/', methods=['GET', 'POST', 'CONNECT', 'PUT', 'DELETE', 'PATCH', 'OPTIONS', 'TRACE', 'HEAD']) def start(): - return 'chatbot is running. Access at https://t.me/gp4free_bot' + """ + The route handler for the root URL. + """ + return 'chatbot is running.' def run(self): + """ + Runs the Flask application on Repl.it if the environment variables REPL_ID and REPL_OWNER are present. + Otherwise, it returns None. + """ if 'REPL_ID' in os.environ and 'REPL_OWNER' in os.environ: print('Running in Repl.it') - thread = threading.Thread(target=self.app.run, kwargs={'host': '0.0.0.0', 'port':8080, 'debug':False, 'use_reloader':False}) + thread = threading.Thread(target=self.app.run, kwargs={'host': '0.0.0.0', 'port': 8080, 'debug': False, 'use_reloader': False}) thread.start() - else: - return None + return None diff --git a/updater.py b/updater.py index 59e6825..e8fa936 100644 --- a/updater.py +++ b/updater.py @@ -4,73 +4,99 @@ import git import requests class SelfUpdating: + """ + The `SelfUpdating` class is responsible for checking if there is a new version available for a given GitHub repository + and updating the local files if necessary. + """ - def __init__(self, repo_url, branch="master"): + def __init__(self, repo_url: str, branch: str = "master"): + """ + Initializes the `SelfUpdating` object with the repository URL and branch. + + Args: + repo_url (str): The URL of the GitHub repository. + branch (str, optional): The branch to track for updates. Defaults to "master". + """ self.repo_url = f'https://github.com/{repo_url}' self.repo_name = repo_url self.branch = branch self.current_version = self.get_current_version() - + def check_for_update(self): + """ + Checks if a new version is available and updates the files if necessary. + """ latest_tag = self.get_latest_tag_from_github(self.repo_name) - + if latest_tag != self.current_version: print(f"New version {latest_tag} available! Updating...") self.update() else: print(f"Already on latest version {self.current_version}") - + def update(self): - temp_dir = "./temp/" + """ + Updates the files by cloning the repository, pulling the latest changes, and copying the updated files. + """ + temp_dir = "./temp/" if not os.path.exists(temp_dir): git.Git(".").clone(self.repo_url, temp_dir) try: - # Checkout latest commit repo = git.Repo(temp_dir) - repo.git.checkout('master') + repo.git.checkout(self.branch) repo.git.pull() - except git.exc.GitCommandError as e: - # Handle pull error + except git.exc.GitCommandError as e: print(f"Git pull failed: {e}") return - # Walk through temp dir + changed_files = [] for root, dirs, files in os.walk(temp_dir): if '.git' in dirs: dirs.remove('.git') for file in files: file_path = os.path.join(root, file) - # Calculate hash of file current_hash = hashlib.sha256(open(file_path, "rb").read()).hexdigest() - # Construct destination path - destination_path = os.path.join(os.getcwd(), file_path.split(temp_dir, 1)[1]) + destination_path = os.path.join(os.getcwd(), file_path.split(temp_dir, 1)[1]) if os.path.exists(destination_path): - # Calculate hash of existing file existing_hash = hashlib.sha256(open(destination_path, "rb").read()).hexdigest() else: existing_hash = "" - # Only overwrite if hashes don't match if current_hash != existing_hash: - shutil.copyfile(file_path, destination_path) + shutil.copyfile(file_path, destination_path) changed_files.append(file) - # Delete temp dir + try: shutil.rmtree(temp_dir) except OSError as e: print(f"Error removing temp dir: {e}") - + self.current_version = self.get_current_version() - - def get_current_version(self): + + def get_current_version(self) -> str: + """ + Returns the current version of the repository. + + Returns: + str: The current version of the repository. + """ # Return current version somehow return "0.6" - def get_latest_tag_from_github(self,repo_url): + def get_latest_tag_from_github(self, repo_url: str) -> str: + """ + Retrieves the latest tag from the GitHub API. + + Args: + repo_url (str): The URL of the GitHub repository. + + Returns: + str: The latest tag from the GitHub API. + """ api_url = f"https://api.github.com/repos/{repo_url}/releases/latest" response = requests.get(api_url) if response.ok: - release = response.json() - return release["name"] + release = response.json() + return release["name"] else: - print(f"Error fetching latest release: {response}") - return "None" \ No newline at end of file + print(f"Error fetching latest release: {response}") + return "None" \ No newline at end of file