diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index c470f25bb..2ec85b147 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -42,6 +42,12 @@ jobs: cd ./Tests/RAG pytest test_RAG_Library_2.py + - name: Test RAG Notes functions with pytest + run: | + pwd + cd ./Tests/RAG_QA_Chat + pytest test_notes_search.py + - name: Test SQLite lib functions with pytest run: | pwd diff --git a/App_Function_Libraries/Audio/Audio_Files.py b/App_Function_Libraries/Audio/Audio_Files.py index b3eabbde3..1fa0b1afa 100644 --- a/App_Function_Libraries/Audio/Audio_Files.py +++ b/App_Function_Libraries/Audio/Audio_Files.py @@ -117,16 +117,15 @@ def process_audio_files(audio_urls, audio_file, whisper_model, api_name, api_key progress = [] all_transcriptions = [] all_summaries = [] - #v2 + temp_files = [] # Keep track of temporary files + def format_transcription_with_timestamps(segments): if keep_timestamps: formatted_segments = [] for segment in segments: start = segment.get('Time_Start', 0) end = segment.get('Time_End', 0) - text = segment.get('Text', '').strip() # Ensure text is stripped of leading/trailing spaces - - # Add the formatted timestamp and text to the list, followed by a newline + text = segment.get('Text', '').strip() formatted_segments.append(f"[{start:.2f}-{end:.2f}] {text}") # Join the segments with a newline to ensure proper formatting @@ -191,205 +190,64 @@ def convert_mp3_to_wav(mp3_file_path): 'language': chunk_language } - # Process multiple URLs - urls = [url.strip() for url in audio_urls.split('\n') if url.strip()] - - for i, url in enumerate(urls): - update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}") - - # Download and process audio file - audio_file_path = download_audio_file(url, use_cookies, cookies) - if not os.path.exists(audio_file_path): - update_progress(f"Downloaded file not found: {audio_file_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(audio_file_path) - update_progress("Audio file downloaded successfully.") - - # Re-encode MP3 to fix potential issues - reencoded_mp3_path = reencode_mp3(audio_file_path) - if not os.path.exists(reencoded_mp3_path): - update_progress(f"Re-encoded file not found: {reencoded_mp3_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(reencoded_mp3_path) - - # Convert re-encoded MP3 to WAV - wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) - if not os.path.exists(wav_file_path): - update_progress(f"Converted WAV file not found: {wav_file_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(wav_file_path) - - # Initialize transcription - transcription = "" - - # Transcribe audio - if diarize: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True) - else: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model) - - # Handle segments nested under 'segments' key - if isinstance(segments, dict) and 'segments' in segments: - segments = segments['segments'] - - if isinstance(segments, list): - # Log first 5 segments for debugging - logging.debug(f"Segments before formatting: {segments[:5]}") - 
transcription = format_transcription_with_timestamps(segments) - logging.debug(f"Formatted transcription (first 500 chars): {transcription[:500]}") - update_progress("Audio transcribed successfully.") - else: - update_progress("Unexpected segments format received from speech_to_text.") - logging.error(f"Unexpected segments format: {segments}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - if not transcription.strip(): - update_progress("Transcription is empty.") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - else: - # Apply chunking - chunked_text = improved_chunking_process(transcription, chunk_options) - - # Summarize - logging.debug(f"Audio Transcription API Name: {api_name}") - if api_name: - try: - summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) - update_progress("Audio summarized successfully.") - except Exception as e: - logging.error(f"Error during summarization: {str(e)}") - summary = "Summary generation failed" - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - else: - summary = "No summary available (API not provided)" + # Process URLs if provided + if audio_urls: + urls = [url.strip() for url in audio_urls.split('\n') if url.strip()] + for i, url in enumerate(urls): + try: + update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}") - all_transcriptions.append(transcription) - all_summaries.append(summary) + # Download and process audio file + audio_file_path = download_audio_file(url, use_cookies, cookies) + if not audio_file_path: + raise FileNotFoundError(f"Failed to download audio from URL: {url}") - # Use custom_title if provided, otherwise use the original filename - title = custom_title if custom_title else os.path.basename(wav_file_path) - - # Add to database - add_media_with_keywords( - url=url, - title=title, - media_type='audio', - content=transcription, - keywords=custom_keywords, - prompt=custom_prompt_input, - summary=summary, - transcription_model=whisper_model, - author="Unknown", - ingestion_date=datetime.now().strftime('%Y-%m-%d') - ) - update_progress("Audio file processed and added to database.") - processed_count += 1 - log_counter( - metric_name="audio_files_processed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - - # Process uploaded file if provided - if audio_file: - url = generate_unique_id() - if os.path.getsize(audio_file.name) > MAX_FILE_SIZE: - update_progress( - f"Uploaded file size exceeds the maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB. 
Skipping this file.") - else: - try: - # Re-encode MP3 to fix potential issues - reencoded_mp3_path = reencode_mp3(audio_file.name) - if not os.path.exists(reencoded_mp3_path): - update_progress(f"Re-encoded file not found: {reencoded_mp3_path}") - return update_progress("Processing failed: Re-encoded file not found"), "", "" + temp_files.append(audio_file_path) + # Process the audio file + reencoded_mp3_path = reencode_mp3(audio_file_path) temp_files.append(reencoded_mp3_path) - # Convert re-encoded MP3 to WAV wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) - if not os.path.exists(wav_file_path): - update_progress(f"Converted WAV file not found: {wav_file_path}") - return update_progress("Processing failed: Converted WAV file not found"), "", "" - temp_files.append(wav_file_path) - # Initialize transcription - transcription = "" - - if diarize: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True) - else: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model) + # Transcribe audio + segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=diarize) - # Handle segments nested under 'segments' key + # Handle segments format if isinstance(segments, dict) and 'segments' in segments: segments = segments['segments'] - if isinstance(segments, list): - transcription = format_transcription_with_timestamps(segments) - else: - update_progress("Unexpected segments format received from speech_to_text.") - logging.error(f"Unexpected segments format: {segments}") + if not isinstance(segments, list): + raise ValueError("Unexpected segments format received from speech_to_text") - chunked_text = improved_chunking_process(transcription, chunk_options) + transcription = format_transcription_with_timestamps(segments) + if not transcription.strip(): + raise ValueError("Empty transcription generated") - logging.debug(f"Audio Transcription API Name: {api_name}") - if api_name: + # Initialize summary with default value + summary = "No summary available" + + # Attempt summarization if API is provided + if api_name and api_name.lower() != "none": try: - summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + chunked_text = improved_chunking_process(transcription, chunk_options) + summary_result = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + if summary_result: + summary = summary_result update_progress("Audio summarized successfully.") except Exception as e: - logging.error(f"Error during summarization: {str(e)}") + logging.error(f"Summarization failed: {str(e)}") summary = "Summary generation failed" - else: - summary = "No summary available (API not provided)" + # Add to results all_transcriptions.append(transcription) all_summaries.append(summary) - # Use custom_title if provided, otherwise use the original filename + # Add to database title = custom_title if custom_title else os.path.basename(wav_file_path) - add_media_with_keywords( - url="Uploaded File", + url=url, title=title, media_type='audio', content=transcription, @@ -400,65 +258,112 @@ def convert_mp3_to_wav(mp3_file_path): author="Unknown", ingestion_date=datetime.now().strftime('%Y-%m-%d') ) - update_progress("Uploaded file processed and added to database.") + processed_count += 1 - log_counter( - metric_name="audio_files_processed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) + update_progress(f"Successfully processed URL {i + 1}") + 
log_counter("audio_files_processed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + except Exception as e: - update_progress(f"Error processing uploaded file: {str(e)}") - logging.error(f"Error processing uploaded file: {str(e)}") failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - return update_progress("Processing failed: Error processing uploaded file"), "", "" - # Final cleanup - if not keep_original: - cleanup_files() + update_progress(f"Failed to process URL {i + 1}: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + continue - end_time = time.time() - processing_time = end_time - start_time - # Log processing time - log_histogram( - metric_name="audio_processing_time_seconds", - value=processing_time, - labels={"whisper_model": whisper_model, "api_name": api_name} - ) + # Process uploaded file if provided + if audio_file: + try: + update_progress("Processing uploaded file...") + if os.path.getsize(audio_file.name) > MAX_FILE_SIZE: + raise ValueError(f"File size exceeds maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB") - # Optionally, log total counts - log_counter( - metric_name="total_audio_files_processed", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=processed_count - ) + reencoded_mp3_path = reencode_mp3(audio_file.name) + temp_files.append(reencoded_mp3_path) - log_counter( - metric_name="total_audio_files_failed", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=failed_count - ) + wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) + temp_files.append(wav_file_path) + + # Transcribe audio + segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=diarize) + + if isinstance(segments, dict) and 'segments' in segments: + segments = segments['segments'] + + if not isinstance(segments, list): + raise ValueError("Unexpected segments format received from speech_to_text") + transcription = format_transcription_with_timestamps(segments) + if not transcription.strip(): + raise ValueError("Empty transcription generated") + + # Initialize summary with default value + summary = "No summary available" - final_progress = update_progress("All processing complete.") - final_transcriptions = "\n\n".join(all_transcriptions) - final_summaries = "\n\n".join(all_summaries) + # Attempt summarization if API is provided + if api_name and api_name.lower() != "none": + try: + chunked_text = improved_chunking_process(transcription, chunk_options) + summary_result = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + if summary_result: + summary = summary_result + update_progress("Audio summarized successfully.") + except Exception as e: + logging.error(f"Summarization failed: {str(e)}") + summary = "Summary generation failed" + + # Add to results + all_transcriptions.append(transcription) + all_summaries.append(summary) + + # Add to database + title = custom_title if custom_title else os.path.basename(wav_file_path) + add_media_with_keywords( + url="Uploaded File", + title=title, + media_type='audio', + content=transcription, + keywords=custom_keywords, + prompt=custom_prompt_input, + summary=summary, + transcription_model=whisper_model, + author="Unknown", + ingestion_date=datetime.now().strftime('%Y-%m-%d') + ) + + processed_count += 1 + update_progress("Successfully processed uploaded file") + 
log_counter("audio_files_processed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + + except Exception as e: + failed_count += 1 + update_progress(f"Failed to process uploaded file: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + + # Cleanup temporary files + if not keep_original: + cleanup_files() + + # Log processing metrics + processing_time = time.time() - start_time + log_histogram("audio_processing_time_seconds", processing_time, + {"whisper_model": whisper_model, "api_name": api_name}) + log_counter("total_audio_files_processed", processed_count, + {"whisper_model": whisper_model, "api_name": api_name}) + log_counter("total_audio_files_failed", failed_count, + {"whisper_model": whisper_model, "api_name": api_name}) + + # Prepare final output + final_progress = update_progress(f"Processing complete. Processed: {processed_count}, Failed: {failed_count}") + final_transcriptions = "\n\n".join(all_transcriptions) if all_transcriptions else "No transcriptions available" + final_summaries = "\n\n".join(all_summaries) if all_summaries else "No summaries available" return final_progress, final_transcriptions, final_summaries except Exception as e: - logging.error(f"Error processing audio files: {str(e)}") - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - cleanup_files() - return update_progress(f"Processing failed: {str(e)}"), "", "" + logging.error(f"Error in process_audio_files: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + if not keep_original: + cleanup_files() + return update_progress(f"Processing failed: {str(e)}"), "No transcriptions available", "No summaries available" def format_transcription_with_timestamps(segments, keep_timestamps): diff --git a/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py b/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py index a17387980..a6c7651d3 100644 --- a/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py +++ b/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py @@ -24,7 +24,7 @@ wait_random_exponential, ) -from App_Function_Libraries.Chat import chat_api_call +from App_Function_Libraries.Chat.Chat_Functions import chat_api_call # ####################################################################################################################### diff --git a/App_Function_Libraries/Books/Book_Ingestion_Lib.py b/App_Function_Libraries/Books/Book_Ingestion_Lib.py index 488d1ba3e..72a9c2f18 100644 --- a/App_Function_Libraries/Books/Book_Ingestion_Lib.py +++ b/App_Function_Libraries/Books/Book_Ingestion_Lib.py @@ -385,109 +385,103 @@ def process_markdown_content(markdown_content, file_path, title, author, keyword return f"Document '{title}' imported successfully. 
Database result: {result}" -def import_file_handler(file, - title, - author, - keywords, - system_prompt, - custom_prompt, - auto_summarize, - api_name, - api_key, - max_chunk_size, - chunk_overlap, - custom_chapter_pattern - ): +def import_file_handler(files, + author, + keywords, + system_prompt, + custom_prompt, + auto_summarize, + api_name, + api_key, + max_chunk_size, + chunk_overlap, + custom_chapter_pattern): try: - log_counter("file_import_attempt", labels={"file_name": file.name}) - - # Handle max_chunk_size - if isinstance(max_chunk_size, str): - max_chunk_size = int(max_chunk_size) if max_chunk_size.strip() else 4000 - elif not isinstance(max_chunk_size, int): - max_chunk_size = 4000 # Default value if not a string or int - - # Handle chunk_overlap - if isinstance(chunk_overlap, str): - chunk_overlap = int(chunk_overlap) if chunk_overlap.strip() else 0 - elif not isinstance(chunk_overlap, int): - chunk_overlap = 0 # Default value if not a string or int - - chunk_options = { - 'method': 'chapter', - 'max_size': max_chunk_size, - 'overlap': chunk_overlap, - 'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None - } + if not files: + return "No files uploaded." - if file is None: - log_counter("file_import_error", labels={"error": "No file uploaded"}) - return "No file uploaded." + # Convert single file to list for consistent processing + if not isinstance(files, list): + files = [files] - file_path = file.name - if not os.path.exists(file_path): - log_counter("file_import_error", labels={"error": "File not found", "file_name": file.name}) - return "Uploaded file not found." + results = [] + for file in files: + log_counter("file_import_attempt", labels={"file_name": file.name}) - start_time = datetime.now() + # Handle max_chunk_size and chunk_overlap + chunk_size = int(max_chunk_size) if isinstance(max_chunk_size, (str, int)) else 4000 + overlap = int(chunk_overlap) if isinstance(chunk_overlap, (str, int)) else 0 - if file_path.lower().endswith('.epub'): - status = import_epub( - file_path, - title, - author, - keywords, - custom_prompt=custom_prompt, - system_prompt=system_prompt, - summary=None, - auto_summarize=auto_summarize, - api_name=api_name, - api_key=api_key, - chunk_options=chunk_options, - custom_chapter_pattern=custom_chapter_pattern - ) - log_counter("epub_import_success", labels={"file_name": file.name}) - result = f"📚 EPUB Imported Successfully:\n{status}" - elif file.name.lower().endswith('.zip'): - status = process_zip_file( - zip_file=file, - title=title, - author=author, - keywords=keywords, - custom_prompt=custom_prompt, - system_prompt=system_prompt, - summary=None, - auto_summarize=auto_summarize, - api_name=api_name, - api_key=api_key, - chunk_options=chunk_options - ) - log_counter("zip_import_success", labels={"file_name": file.name}) - result = f"📦 ZIP Processed Successfully:\n{status}" - elif file.name.lower().endswith(('.chm', '.html', '.pdf', '.xml', '.opml')): - file_type = file.name.split('.')[-1].upper() - log_counter("unsupported_file_type", labels={"file_type": file_type}) - result = f"{file_type} file import is not yet supported." - else: - log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]}) - result = "❌ Unsupported file type. Please upload an `.epub` file or a `.zip` file containing `.epub` files." 
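+            # Chapter-based chunking settings; these are passed unchanged to
+            # both import_epub() and process_zip_file() below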
+ chunk_options = { + 'method': 'chapter', + 'max_size': chunk_size, + 'overlap': overlap, + 'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None + } - end_time = datetime.now() - processing_time = (end_time - start_time).total_seconds() - log_histogram("file_import_duration", processing_time, labels={"file_name": file.name}) + file_path = file.name + if not os.path.exists(file_path): + results.append(f"❌ File not found: {file.name}") + continue - return result + start_time = datetime.now() + + # Extract title from filename + title = os.path.splitext(os.path.basename(file_path))[0] + + if file_path.lower().endswith('.epub'): + status = import_epub( + file_path, + title=title, # Use filename as title + author=author, + keywords=keywords, + custom_prompt=custom_prompt, + system_prompt=system_prompt, + summary=None, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key, + chunk_options=chunk_options, + custom_chapter_pattern=custom_chapter_pattern + ) + log_counter("epub_import_success", labels={"file_name": file.name}) + results.append(f"📚 {file.name}: {status}") + + elif file_path.lower().endswith('.zip'): + status = process_zip_file( + zip_file=file, + title=None, # Let each file use its own name + author=author, + keywords=keywords, + custom_prompt=custom_prompt, + system_prompt=system_prompt, + summary=None, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key, + chunk_options=chunk_options + ) + log_counter("zip_import_success", labels={"file_name": file.name}) + results.append(f"📦 {file.name}: {status}") + else: + results.append(f"❌ Unsupported file type: {file.name}") + continue + + end_time = datetime.now() + processing_time = (end_time - start_time).total_seconds() + log_histogram("file_import_duration", processing_time, labels={"file_name": file.name}) + + return "\n\n".join(results) except ValueError as ve: logging.exception(f"Error parsing input values: {str(ve)}") - log_counter("file_import_error", labels={"error": "Invalid input", "file_name": file.name}) return f"❌ Error: Invalid input for chunk size or overlap. Please enter valid numbers." except Exception as e: logging.exception(f"Error during file import: {str(e)}") - log_counter("file_import_error", labels={"error": str(e), "file_name": file.name}) return f"❌ Error during import: {str(e)}" + def read_epub(file_path): """ Reads and extracts text from an EPUB file. 
@@ -568,9 +562,9 @@ def ingest_text_file(file_path, title=None, author=None, keywords=None): # Add the text file to the database add_media_with_keywords( - url=file_path, + url="its_a_book", title=title, - media_type='document', + media_type='book', content=content, keywords=keywords, prompt='No prompt for text files', diff --git a/App_Function_Libraries/Chat.py b/App_Function_Libraries/Chat/Chat_Functions.py similarity index 86% rename from App_Function_Libraries/Chat.py rename to App_Function_Libraries/Chat/Chat_Functions.py index a86ab66cb..f595d40e7 100644 --- a/App_Function_Libraries/Chat.py +++ b/App_Function_Libraries/Chat/Chat_Functions.py @@ -1,4 +1,4 @@ -# Chat.py +# Chat_Functions.py # Chat functions for interacting with the LLMs as chatbots import base64 # Imports @@ -6,6 +6,7 @@ import logging import os import re +import sqlite3 import tempfile import time from datetime import datetime @@ -14,7 +15,8 @@ # External Imports # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_conversation_name, save_chat_history_to_database +from App_Function_Libraries.DB.DB_Manager import start_new_conversation, delete_messages_in_conversation, save_message +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection, get_conversation_name from App_Function_Libraries.LLM_API_Calls import chat_with_openai, chat_with_anthropic, chat_with_cohere, \ chat_with_groq, chat_with_openrouter, chat_with_deepseek, chat_with_mistral, chat_with_huggingface from App_Function_Libraries.LLM_API_Calls_Local import chat_with_aphrodite, chat_with_local_llm, chat_with_ollama, \ @@ -27,6 +29,16 @@ # # Functions: +def approximate_token_count(history): + total_text = '' + for user_msg, bot_msg in history: + if user_msg: + total_text += user_msg + ' ' + if bot_msg: + total_text += bot_msg + ' ' + total_tokens = len(total_text.split()) + return total_tokens + def chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message=None): log_counter("chat_api_call_attempt", labels={"api_endpoint": api_endpoint}) start_time = time.time() @@ -173,56 +185,58 @@ def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, med log_counter("save_chat_history_to_db_attempt") start_time = time.time() logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}") + try: - # Extract the media_id and media_name from the media_content - media_id = None - if isinstance(media_content, dict): + # First check if we can access the database + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + except sqlite3.DatabaseError as db_error: + logging.error(f"Database is corrupted or inaccessible: {str(db_error)}") + return conversation_id, "Database error: The database file appears to be corrupted. Please contact support." 
+ + # Now attempt the save + if not conversation_id: + # Only for new conversations, not updates media_id = None - logging.debug(f"Media content keys: {media_content.keys()}") - if 'content' in media_content: + if isinstance(media_content, dict) and 'content' in media_content: try: content = media_content['content'] - if isinstance(content, str): - content_json = json.loads(content) - elif isinstance(content, dict): - content_json = content - else: - raise ValueError(f"Unexpected content type: {type(content)}") - - # Use the webpage_url as the media_id + content_json = content if isinstance(content, dict) else json.loads(content) media_id = content_json.get('webpage_url') - # Use the title as the media_name - media_name = content_json.get('title') - - logging.info(f"Extracted media_id: {media_id}, media_name: {media_name}") - except json.JSONDecodeError: - logging.error("Failed to decode JSON from media_content['content']") - except Exception as e: - logging.error(f"Error processing media_content: {str(e)}") + media_name = media_name or content_json.get('title', 'Unnamed Media') + except (json.JSONDecodeError, AttributeError) as e: + logging.error(f"Error processing media content: {str(e)}") + media_id = "unknown_media" + media_name = media_name or "Unnamed Media" else: - logging.warning("'content' key not found in media_content") - else: - logging.warning(f"media_content is not a dictionary. Type: {type(media_content)}") - - if media_id is None: - # If we couldn't find a media_id, we'll use a placeholder - media_id = "unknown_media" - logging.warning(f"Unable to extract media_id from media_content. Using placeholder: {media_id}") - - if media_name is None: - media_name = "Unnamed Media" - logging.warning(f"Unable to extract media_name from media_content. Using placeholder: {media_name}") + media_id = "unknown_media" + media_name = media_name or "Unnamed Media" + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + conversation_title = f"{media_name}_{timestamp}" + conversation_id = start_new_conversation(title=conversation_title, media_id=media_id) + logging.info(f"Created new conversation with ID: {conversation_id}") + + # For both new and existing conversations + try: + delete_messages_in_conversation(conversation_id) + for user_msg, assistant_msg in chatbot: + if user_msg: + save_message(conversation_id, "user", user_msg) + if assistant_msg: + save_message(conversation_id, "assistant", assistant_msg) + except sqlite3.DatabaseError as db_error: + logging.error(f"Database error during message save: {str(db_error)}") + return conversation_id, "Database error: Unable to save messages. Please try again or contact support." - # Generate a unique conversation name using media_id and current timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - conversation_name = f"{media_name}_{timestamp}" - - new_conversation_id = save_chat_history_to_database(chatbot, conversation_id, media_id, media_name, - conversation_name) save_duration = time.time() - start_time log_histogram("save_chat_history_to_db_duration", save_duration) log_counter("save_chat_history_to_db_success") - return new_conversation_id, f"Chat history saved successfully as {conversation_name}!" + + return conversation_id, "Chat history saved successfully!" 
+ except Exception as e: log_counter("save_chat_history_to_db_error", labels={"error": str(e)}) error_message = f"Failed to save chat history: {str(e)}" diff --git a/App_Function_Libraries/Chat/__init__.py b/App_Function_Libraries/Chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/App_Function_Libraries/Chunk_Lib.py b/App_Function_Libraries/Chunk_Lib.py index 2a37d3538..8d9ff2b60 100644 --- a/App_Function_Libraries/Chunk_Lib.py +++ b/App_Function_Libraries/Chunk_Lib.py @@ -106,6 +106,7 @@ def load_document(file_path: str) -> str: def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) -> List[Dict[str, Any]]: logging.debug("Improved chunking process started...") + logging.debug(f"Received chunk_options: {chunk_options}") # Extract JSON metadata if present json_content = {} @@ -125,49 +126,70 @@ def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) - text = text[len(header_text):].strip() logging.debug(f"Extracted header text: {header_text}") - options = chunk_options.copy() if chunk_options else {} + # Make a copy of chunk_options and ensure values are correct types + options = {} if chunk_options: - options.update(chunk_options) - - chunk_method = options.get('method', 'words') - max_size = options.get('max_size', 2000) - overlap = options.get('overlap', 0) - language = options.get('language', None) + try: + options['method'] = str(chunk_options.get('method', 'words')) + options['max_size'] = int(chunk_options.get('max_size', 2000)) + options['overlap'] = int(chunk_options.get('overlap', 0)) + # Handle language specially - it can be None + lang = chunk_options.get('language') + options['language'] = str(lang) if lang is not None else None + logging.debug(f"Processed options: {options}") + except Exception as e: + logging.error(f"Error processing chunk options: {e}") + raise + else: + options = {'method': 'words', 'max_size': 2000, 'overlap': 0, 'language': None} + logging.debug("Using default options") - if language is None: - language = detect_language(text) + if options.get('language') is None: + detected_lang = detect_language(text) + options['language'] = str(detected_lang) + logging.debug(f"Detected language: {options['language']}") - if chunk_method == 'json': - chunks = chunk_text_by_json(text, max_size=max_size, overlap=overlap) - else: - chunks = chunk_text(text, chunk_method, max_size, overlap, language) + try: + if options['method'] == 'json': + chunks = chunk_text_by_json(text, max_size=options['max_size'], overlap=options['overlap']) + else: + chunks = chunk_text(text, options['method'], options['max_size'], options['overlap'], options['language']) + logging.debug(f"Created {len(chunks)} chunks using method {options['method']}") + except Exception as e: + logging.error(f"Error in chunking process: {e}") + raise chunks_with_metadata = [] total_chunks = len(chunks) - for i, chunk in enumerate(chunks): - metadata = { - 'chunk_index': i + 1, - 'total_chunks': total_chunks, - 'chunk_method': chunk_method, - 'max_size': max_size, - 'overlap': overlap, - 'language': language, - 'relative_position': (i + 1) / total_chunks - } - metadata.update(json_content) # Add the extracted JSON content to metadata - metadata['header_text'] = header_text # Add the header text to metadata - - if chunk_method == 'json': - chunk_text_content = json.dumps(chunk['json'], ensure_ascii=False) - else: - chunk_text_content = chunk + try: + for i, chunk in enumerate(chunks): + metadata = { + 'chunk_index': i + 1, + 'total_chunks': 
total_chunks, + 'chunk_method': options['method'], + 'max_size': options['max_size'], + 'overlap': options['overlap'], + 'language': options['language'], + 'relative_position': float((i + 1) / total_chunks) + } + metadata.update(json_content) + metadata['header_text'] = header_text + + if options['method'] == 'json': + chunk_text_content = json.dumps(chunk['json'], ensure_ascii=False) + else: + chunk_text_content = chunk - chunks_with_metadata.append({ - 'text': chunk_text_content, - 'metadata': metadata - }) + chunks_with_metadata.append({ + 'text': chunk_text_content, + 'metadata': metadata + }) - return chunks_with_metadata + logging.debug(f"Successfully created metadata for all chunks") + return chunks_with_metadata + except Exception as e: + logging.error(f"Error creating chunk metadata: {e}") + raise def multi_level_chunking(text: str, method: str, max_size: int, overlap: int, language: str) -> List[str]: @@ -220,24 +242,35 @@ def determine_chunk_position(relative_position: float) -> str: def chunk_text_by_words(text: str, max_words: int = 300, overlap: int = 0, language: str = None) -> List[str]: logging.debug("chunk_text_by_words...") - if language is None: - language = detect_language(text) + logging.debug(f"Parameters: max_words={max_words}, overlap={overlap}, language={language}") - if language.startswith('zh'): # Chinese - import jieba - words = list(jieba.cut(text)) - elif language == 'ja': # Japanese - import fugashi - tagger = fugashi.Tagger() - words = [word.surface for word in tagger(text)] - else: # Default to simple splitting for other languages - words = text.split() + try: + if language is None: + language = detect_language(text) + logging.debug(f"Detected language: {language}") + + if language.startswith('zh'): # Chinese + import jieba + words = list(jieba.cut(text)) + elif language == 'ja': # Japanese + import fugashi + tagger = fugashi.Tagger() + words = [word.surface for word in tagger(text)] + else: # Default to simple splitting for other languages + words = text.split() + + logging.debug(f"Total words: {len(words)}") - chunks = [] - for i in range(0, len(words), max_words - overlap): - chunk = ' '.join(words[i:i + max_words]) - chunks.append(chunk) - return post_process_chunks(chunks) + chunks = [] + for i in range(0, len(words), max_words - overlap): + chunk = ' '.join(words[i:i + max_words]) + chunks.append(chunk) + logging.debug(f"Created chunk {len(chunks)} with {len(chunk.split())} words") + + return post_process_chunks(chunks) + except Exception as e: + logging.error(f"Error in chunk_text_by_words: {e}") + raise def chunk_text_by_sentences(text: str, max_sentences: int = 10, overlap: int = 0, language: str = None) -> List[str]: @@ -338,24 +371,24 @@ def get_chunk_metadata(chunk: str, full_text: str, chunk_type: str = "generic", """ chunk_length = len(chunk) start_index = full_text.find(chunk) - end_index = start_index + chunk_length if start_index != -1 else None + end_index = start_index + chunk_length if start_index != -1 else -1 # Calculate a hash for the chunk chunk_hash = hashlib.md5(chunk.encode()).hexdigest() metadata = { - 'start_index': start_index, - 'end_index': end_index, - 'word_count': len(chunk.split()), - 'char_count': chunk_length, + 'start_index': int(start_index), + 'end_index': int(end_index), + 'word_count': int(len(chunk.split())), + 'char_count': int(chunk_length), 'chunk_type': chunk_type, 'language': language, 'chunk_hash': chunk_hash, - 'relative_position': start_index / len(full_text) if len(full_text) > 0 and start_index != -1 
else 0 + 'relative_position': float(start_index / len(full_text) if len(full_text) > 0 and start_index != -1 else 0) } if chunk_type == "chapter": - metadata['chapter_number'] = chapter_number + metadata['chapter_number'] = int(chapter_number) if chapter_number is not None else None metadata['chapter_pattern'] = chapter_pattern return metadata diff --git a/App_Function_Libraries/DB/Character_Chat_DB.py b/App_Function_Libraries/DB/Character_Chat_DB.py index 4050b3fb8..63bbac637 100644 --- a/App_Function_Libraries/DB/Character_Chat_DB.py +++ b/App_Function_Libraries/DB/Character_Chat_DB.py @@ -72,6 +72,54 @@ def initialize_database(): ); """) + # Create FTS5 virtual table for CharacterCards + cursor.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS CharacterCards_fts USING fts5( + name, + description, + personality, + scenario, + system_prompt, + content='CharacterCards', + content_rowid='id' + ); + """) + + # Create triggers to keep FTS5 table in sync with CharacterCards + cursor.executescript(""" + CREATE TRIGGER IF NOT EXISTS CharacterCards_ai AFTER INSERT ON CharacterCards BEGIN + INSERT INTO CharacterCards_fts( + rowid, + name, + description, + personality, + scenario, + system_prompt + ) VALUES ( + new.id, + new.name, + new.description, + new.personality, + new.scenario, + new.system_prompt + ); + END; + + CREATE TRIGGER IF NOT EXISTS CharacterCards_ad AFTER DELETE ON CharacterCards BEGIN + DELETE FROM CharacterCards_fts WHERE rowid = old.id; + END; + + CREATE TRIGGER IF NOT EXISTS CharacterCards_au AFTER UPDATE ON CharacterCards BEGIN + UPDATE CharacterCards_fts SET + name = new.name, + description = new.description, + personality = new.personality, + scenario = new.scenario, + system_prompt = new.system_prompt + WHERE rowid = new.id; + END; + """) + # Create CharacterChats table cursor.execute(""" CREATE TABLE IF NOT EXISTS CharacterChats ( @@ -155,6 +203,7 @@ def setup_chat_database(): setup_chat_database() + ######################################################################################################## # # Character Card handling @@ -560,6 +609,7 @@ def delete_character_chat(chat_id: int) -> bool: finally: conn.close() + def fetch_keywords_for_chats(keywords: List[str]) -> List[int]: """ Fetch chat IDs associated with any of the specified keywords. @@ -589,6 +639,7 @@ def fetch_keywords_for_chats(keywords: List[str]) -> List[int]: finally: conn.close() + def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]: """Save chat history to the CharacterChats table. @@ -596,9 +647,6 @@ def save_chat_history_to_character_db(character_id: int, conversation_name: str, """ return add_character_chat(character_id, conversation_name, chat_history) -def migrate_chat_to_media_db(): - pass - def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]: """ @@ -696,6 +744,316 @@ def fetch_all_chats() -> List[Dict[str, Any]]: logging.error(f"Error fetching all chats: {str(e)}") return [] + +def search_character_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the Character Chat database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of character IDs to filter results. + + Returns: + List of search results with content and metadata. 
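+
+    Example (illustrative; result shape matches the implementation below):
+        results = search_character_chat("dragon", fts_top_k=5)
+        for r in results:
+            print(r["metadata"]["conversation_name"])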
+ """ + if not query.strip(): + return [] + + try: + # Construct a WHERE clause to limit the search to relevant character IDs + where_clause = "" + if relevant_media_ids: + placeholders = ','.join(['?'] * len(relevant_media_ids)) + where_clause = f"CharacterChats.character_id IN ({placeholders})" + + # Perform full-text search using existing search_db function + results = search_db(query, ["conversation_name", "chat_history"], where_clause, results_per_page=fts_top_k) + + # Format results + formatted_results = [] + for r in results: + formatted_results.append({ + "content": r['chat_history'], + "metadata": { + "chat_id": r['id'], + "conversation_name": r['conversation_name'], + "character_id": r['character_id'] + } + }) + + return formatted_results + + except Exception as e: + logging.error(f"Error in search_character_chat: {e}") + return [] + + +def search_character_cards(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the Character Cards database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of character IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + + # Construct the query + sql_query = """ + SELECT CharacterCards.id, CharacterCards.name, CharacterCards.description, CharacterCards.personality, CharacterCards.scenario + FROM CharacterCards_fts + JOIN CharacterCards ON CharacterCards_fts.rowid = CharacterCards.id + WHERE CharacterCards_fts MATCH ? + """ + + params = [query] + + # Add filtering by character IDs if provided + if relevant_media_ids: + placeholders = ','.join(['?'] * len(relevant_media_ids)) + sql_query += f" AND CharacterCards.id IN ({placeholders})" + params.extend(relevant_media_ids) + + sql_query += " LIMIT ?" + params.append(fts_top_k) + + cursor.execute(sql_query, params) + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + + results = [dict(zip(columns, row)) for row in rows] + + # Format results + formatted_results = [] + for r in results: + content = f"Name: {r['name']}\nDescription: {r['description']}\nPersonality: {r['personality']}\nScenario: {r['scenario']}" + formatted_results.append({ + "content": content, + "metadata": { + "character_id": r['id'], + "name": r['name'] + } + }) + + return formatted_results + + except Exception as e: + logging.error(f"Error in search_character_cards: {e}") + return [] + finally: + conn.close() + + +def fetch_character_ids_by_keywords(keywords: List[str]) -> List[int]: + """ + Fetch character IDs associated with any of the specified keywords. + + Args: + keywords (List[str]): List of keywords to search for. + + Returns: + List[int]: List of character IDs associated with the keywords. 
+ """ + if not keywords: + return [] + + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + # Assuming 'tags' column in CharacterCards table stores tags as JSON array + placeholders = ','.join(['?'] * len(keywords)) + sql_query = f""" + SELECT DISTINCT id FROM CharacterCards + WHERE EXISTS ( + SELECT 1 FROM json_each(tags) + WHERE json_each.value IN ({placeholders}) + ) + """ + cursor.execute(sql_query, keywords) + rows = cursor.fetchall() + character_ids = [row[0] for row in rows] + return character_ids + except Exception as e: + logging.error(f"Error in fetch_character_ids_by_keywords: {e}") + return [] + finally: + conn.close() + + +################################################################### +# +# Character Keywords + +def view_char_keywords(): + try: + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT DISTINCT keyword + FROM CharacterCards + CROSS JOIN json_each(tags) + WHERE json_valid(tags) + ORDER BY keyword + """) + keywords = cursor.fetchall() + if keywords: + keyword_list = [k[0] for k in keywords] + return "### Current Character Keywords:\n" + "\n".join( + [f"- {k}" for k in keyword_list]) + return "No keywords found." + except Exception as e: + return f"Error retrieving keywords: {str(e)}" + + +def add_char_keywords(name: str, keywords: str): + try: + keywords_list = [k.strip() for k in keywords.split(",") if k.strip()] + with sqlite3.connect('character_chat.db') as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT tags FROM CharacterCards WHERE name = ?", + (name,) + ) + result = cursor.fetchone() + if not result: + return "Character not found." + + current_tags = result[0] if result[0] else "[]" + current_keywords = set(current_tags[1:-1].split(',')) if current_tags != "[]" else set() + updated_keywords = current_keywords.union(set(keywords_list)) + + cursor.execute( + "UPDATE CharacterCards SET tags = ? WHERE name = ?", + (str(list(updated_keywords)), name) + ) + conn.commit() + return f"Successfully added keywords to character {name}" + except Exception as e: + return f"Error adding keywords: {str(e)}" + + +def delete_char_keyword(char_name: str, keyword: str) -> str: + """ + Delete a keyword from a character's tags. + + Args: + char_name (str): The name of the character + keyword (str): The keyword to delete + + Returns: + str: Success/failure message + """ + try: + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + + # First, check if the character exists + cursor.execute("SELECT tags FROM CharacterCards WHERE name = ?", (char_name,)) + result = cursor.fetchone() + + if not result: + return f"Character '{char_name}' not found." + + # Parse existing tags + current_tags = json.loads(result[0]) if result[0] else [] + + if keyword not in current_tags: + return f"Keyword '{keyword}' not found in character '{char_name}' tags." + + # Remove the keyword + updated_tags = [tag for tag in current_tags if tag != keyword] + + # Update the character's tags + cursor.execute( + "UPDATE CharacterCards SET tags = ? WHERE name = ?", + (json.dumps(updated_tags), char_name) + ) + conn.commit() + + logging.info(f"Keyword '{keyword}' deleted from character '{char_name}'") + return f"Successfully deleted keyword '{keyword}' from character '{char_name}'." 
+ + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_char_keywords_to_csv() -> Tuple[str, str]: + """ + Export all character keywords to a CSV file with associated metadata. + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + from tempfile import NamedTemporaryFile + from datetime import datetime + + try: + # Create a temporary CSV file + temp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', newline='') + + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + + # Get all characters and their tags + cursor.execute(""" + SELECT + name, + tags, + (SELECT COUNT(*) FROM CharacterChats WHERE CharacterChats.character_id = CharacterCards.id) as chat_count + FROM CharacterCards + WHERE json_valid(tags) + ORDER BY name + """) + + results = cursor.fetchall() + + # Process the results to create rows for the CSV + csv_rows = [] + for name, tags_json, chat_count in results: + tags = json.loads(tags_json) if tags_json else [] + for tag in tags: + csv_rows.append([ + tag, # keyword + name, # character name + chat_count # number of chats + ]) + + # Write to CSV + writer = csv.writer(temp_file) + writer.writerow(['Keyword', 'Character Name', 'Number of Chats']) + writer.writerows(csv_rows) + + temp_file.close() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + status_msg = f"Successfully exported {len(csv_rows)} character keyword entries to CSV." + logging.info(status_msg) + + return status_msg, temp_file.name + + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "" + +# +# End of Character chat keyword functions +###################################################### + + # # End of Character_Chat_DB.py ####################################################################################################################### diff --git a/App_Function_Libraries/DB/DB_Backups.py b/App_Function_Libraries/DB/DB_Backups.py new file mode 100644 index 000000000..e3b5b784d --- /dev/null +++ b/App_Function_Libraries/DB/DB_Backups.py @@ -0,0 +1,160 @@ +# Backup_Manager.py +# +# Imports: +import os +import shutil +import sqlite3 +from datetime import datetime +import logging +# +# Local Imports: +from App_Function_Libraries.DB.Character_Chat_DB import chat_DB_PATH +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_rag_qa_db_path +from App_Function_Libraries.Utils.Utils import get_project_relative_path +# +# End of Imports +####################################################################################################################### +# +# Functions: + +def init_backup_directory(backup_base_dir: str, db_name: str) -> str: + """Initialize backup directory for a specific database.""" + backup_dir = os.path.join(backup_base_dir, db_name) + os.makedirs(backup_dir, exist_ok=True) + return backup_dir + + +def create_backup(db_path: str, backup_dir: str, db_name: str) -> str: + """Create a full backup of the database.""" + try: + db_path = os.path.abspath(db_path) + backup_dir = os.path.abspath(backup_dir) + + logging.info(f"Creating backup:") + logging.info(f" DB Path: {db_path}") + logging.info(f" Backup Dir: {backup_dir}") + logging.info(f" DB Name: {db_name}") + + # Create subdirectory based on db_name + specific_backup_dir = os.path.join(backup_dir, db_name) + os.makedirs(specific_backup_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = 
os.path.join(specific_backup_dir, f"{db_name}_backup_{timestamp}.db") + logging.info(f" Full backup path: {backup_file}") + + # Create a backup using SQLite's backup API + with sqlite3.connect(db_path) as source, \ + sqlite3.connect(backup_file) as target: + source.backup(target) + + logging.info(f"Backup created successfully: {backup_file}") + return f"Backup created: {backup_file}" + except Exception as e: + error_msg = f"Failed to create backup: {str(e)}" + logging.error(error_msg) + return error_msg + + +def create_incremental_backup(db_path: str, backup_dir: str, db_name: str) -> str: + """Create an incremental backup using VACUUM INTO.""" + try: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = os.path.join(backup_dir, + f"{db_name}_incremental_{timestamp}.sqlib") + + with sqlite3.connect(db_path) as conn: + conn.execute(f"VACUUM INTO '{backup_file}'") + + logging.info(f"Incremental backup created: {backup_file}") + return f"Incremental backup created: {backup_file}" + except Exception as e: + error_msg = f"Failed to create incremental backup: {str(e)}" + logging.error(error_msg) + return error_msg + + +def list_backups(backup_dir: str) -> str: + """List all available backups.""" + try: + backups = [f for f in os.listdir(backup_dir) + if f.endswith(('.db', '.sqlib'))] + backups.sort(reverse=True) # Most recent first + return "\n".join(backups) if backups else "No backups found" + except Exception as e: + error_msg = f"Failed to list backups: {str(e)}" + logging.error(error_msg) + return error_msg + + +def restore_single_db_backup(db_path: str, backup_dir: str, db_name: str, backup_name: str) -> str: + """Restore database from a backup file.""" + try: + logging.info(f"Restoring backup: {backup_name}") + backup_path = os.path.join(backup_dir, backup_name) + if not os.path.exists(backup_path): + logging.error(f"Backup file not found: {backup_name}") + return f"Backup file not found: {backup_name}" + + # Create a timestamp for the current db + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + current_backup = os.path.join(backup_dir, + f"{db_name}_pre_restore_{timestamp}.db") + + # Backup current database before restore + logging.info(f"Creating backup of current database: {current_backup}") + shutil.copy2(db_path, current_backup) + + # Restore the backup + logging.info(f"Restoring database from {backup_name}") + shutil.copy2(backup_path, db_path) + + logging.info(f"Database restored from {backup_name}") + return f"Database restored from {backup_name}" + except Exception as e: + error_msg = f"Failed to restore backup: {str(e)}" + logging.error(error_msg) + return error_msg + + +def setup_backup_config(): + """Setup configuration for database backups.""" + backup_base_dir = get_project_relative_path('tldw_DB_Backups') + logging.info(f"Base backup directory: {os.path.abspath(backup_base_dir)}") + + # RAG Chat DB configuration + rag_db_path = get_rag_qa_db_path() + rag_backup_dir = os.path.join(backup_base_dir, 'rag_chat') + os.makedirs(rag_backup_dir, exist_ok=True) + logging.info(f"RAG backup directory: {os.path.abspath(rag_backup_dir)}") + + rag_db_config = { + 'db_path': rag_db_path, + 'backup_dir': rag_backup_dir, # Make sure we use the full path + 'db_name': 'rag_qa' + } + + # Character Chat DB configuration + char_backup_dir = os.path.join(backup_base_dir, 'character_chat') + os.makedirs(char_backup_dir, exist_ok=True) + logging.info(f"Character backup directory: {os.path.abspath(char_backup_dir)}") + + char_db_config = { + 'db_path': chat_DB_PATH, + 
'backup_dir': char_backup_dir, # Make sure we use the full path + 'db_name': 'chatDB' + } + + # Media DB configuration (based on your logs) + media_backup_dir = os.path.join(backup_base_dir, 'media') + os.makedirs(media_backup_dir, exist_ok=True) + logging.info(f"Media backup directory: {os.path.abspath(media_backup_dir)}") + + media_db_config = { + 'db_path': os.path.join(os.path.dirname(chat_DB_PATH), 'media_summary.db'), + 'backup_dir': media_backup_dir, + 'db_name': 'media' + } + + return rag_db_config, char_db_config, media_db_config + diff --git a/App_Function_Libraries/DB/DB_Manager.py b/App_Function_Libraries/DB/DB_Manager.py index fc0c7b498..13c3cb15d 100644 --- a/App_Function_Libraries/DB/DB_Manager.py +++ b/App_Function_Libraries/DB/DB_Manager.py @@ -13,11 +13,14 @@ # # Import your existing SQLite functions from App_Function_Libraries.DB.SQLite_DB import DatabaseError +from App_Function_Libraries.DB.Prompts_DB import list_prompts as sqlite_list_prompts, \ + fetch_prompt_details as sqlite_fetch_prompt_details, add_prompt as sqlite_add_prompt, \ + search_prompts as sqlite_search_prompts, add_or_update_prompt as sqlite_add_or_update_prompt, \ + load_prompt_details as sqlite_load_prompt_details, insert_prompt_to_db as sqlite_insert_prompt_to_db, \ + delete_prompt as sqlite_delete_prompt from App_Function_Libraries.DB.SQLite_DB import ( update_media_content as sqlite_update_media_content, - list_prompts as sqlite_list_prompts, search_and_display as sqlite_search_and_display, - fetch_prompt_details as sqlite_fetch_prompt_details, keywords_browser_interface as sqlite_keywords_browser_interface, add_keyword as sqlite_add_keyword, delete_keyword as sqlite_delete_keyword, @@ -25,31 +28,17 @@ ingest_article_to_db as sqlite_ingest_article_to_db, add_media_to_database as sqlite_add_media_to_database, import_obsidian_note_to_db as sqlite_import_obsidian_note_to_db, - add_prompt as sqlite_add_prompt, - delete_chat_message as sqlite_delete_chat_message, - update_chat_message as sqlite_update_chat_message, - add_chat_message as sqlite_add_chat_message, - get_chat_messages as sqlite_get_chat_messages, - search_chat_conversations as sqlite_search_chat_conversations, - create_chat_conversation as sqlite_create_chat_conversation, - save_chat_history_to_database as sqlite_save_chat_history_to_database, view_database as sqlite_view_database, get_transcripts as sqlite_get_transcripts, get_trashed_items as sqlite_get_trashed_items, user_delete_item as sqlite_user_delete_item, empty_trash as sqlite_empty_trash, create_automated_backup as sqlite_create_automated_backup, - add_or_update_prompt as sqlite_add_or_update_prompt, - load_prompt_details as sqlite_load_prompt_details, - load_preset_prompts as sqlite_load_preset_prompts, - insert_prompt_to_db as sqlite_insert_prompt_to_db, - delete_prompt as sqlite_delete_prompt, search_and_display_items as sqlite_search_and_display_items, - get_conversation_name as sqlite_get_conversation_name, add_media_with_keywords as sqlite_add_media_with_keywords, check_media_and_whisper_model as sqlite_check_media_and_whisper_model, \ create_document_version as sqlite_create_document_version, - get_document_version as sqlite_get_document_version, sqlite_search_db, add_media_chunk as sqlite_add_media_chunk, + get_document_version as sqlite_get_document_version, search_media_db as sqlite_search_media_db, add_media_chunk as sqlite_add_media_chunk, sqlite_update_fts_for_media, get_unprocessed_media as sqlite_get_unprocessed_media, fetch_item_details as sqlite_fetch_item_details, 
\ search_media_database as sqlite_search_media_database, mark_as_trash as sqlite_mark_as_trash, \ get_media_transcripts as sqlite_get_media_transcripts, get_specific_transcript as sqlite_get_specific_transcript, \ @@ -60,23 +49,35 @@ delete_specific_prompt as sqlite_delete_specific_prompt, fetch_keywords_for_media as sqlite_fetch_keywords_for_media, \ update_keywords_for_media as sqlite_update_keywords_for_media, check_media_exists as sqlite_check_media_exists, \ - search_prompts as sqlite_search_prompts, get_media_content as sqlite_get_media_content, \ - get_paginated_files as sqlite_get_paginated_files, get_media_title as sqlite_get_media_title, \ - get_all_content_from_database as sqlite_get_all_content_from_database, - get_next_media_id as sqlite_get_next_media_id, \ - batch_insert_chunks as sqlite_batch_insert_chunks, Database, save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, \ - get_workflow_chat as sqlite_get_workflow_chat, update_media_content_with_version as sqlite_update_media_content_with_version, \ + get_media_content as sqlite_get_media_content, get_paginated_files as sqlite_get_paginated_files, \ + get_media_title as sqlite_get_media_title, get_all_content_from_database as sqlite_get_all_content_from_database, \ + get_next_media_id as sqlite_get_next_media_id, batch_insert_chunks as sqlite_batch_insert_chunks, Database, \ + save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, get_workflow_chat as sqlite_get_workflow_chat, \ + update_media_content_with_version as sqlite_update_media_content_with_version, \ check_existing_media as sqlite_check_existing_media, get_all_document_versions as sqlite_get_all_document_versions, \ fetch_paginated_data as sqlite_fetch_paginated_data, get_latest_transcription as sqlite_get_latest_transcription, \ mark_media_as_processed as sqlite_mark_media_as_processed, ) +from App_Function_Libraries.DB.RAG_QA_Chat_DB import start_new_conversation as sqlite_start_new_conversation, \ + save_message as sqlite_save_message, load_chat_history as sqlite_load_chat_history, \ + get_all_conversations as sqlite_get_all_conversations, get_notes_by_keywords as sqlite_get_notes_by_keywords, \ + get_note_by_id as sqlite_get_note_by_id, update_note as sqlite_update_note, save_notes as sqlite_save_notes, \ + clear_keywords_from_note as sqlite_clear_keywords_from_note, add_keywords_to_note as sqlite_add_keywords_to_note, \ + add_keywords_to_conversation as sqlite_add_keywords_to_conversation, \ + get_keywords_for_note as sqlite_get_keywords_for_note, delete_note as sqlite_delete_note, \ + search_conversations_by_keywords as sqlite_search_conversations_by_keywords, \ + delete_conversation as sqlite_delete_conversation, get_conversation_title as sqlite_get_conversation_title, \ + update_conversation_title as sqlite_update_conversation_title, \ + fetch_all_conversations as sqlite_fetch_all_conversations, fetch_all_notes as sqlite_fetch_all_notes, \ + fetch_conversations_by_ids as sqlite_fetch_conversations_by_ids, fetch_notes_by_ids as sqlite_fetch_notes_by_ids, \ + delete_messages_in_conversation as sqlite_delete_messages_in_conversation, \ + get_conversation_text as sqlite_get_conversation_text, search_notes_titles as sqlite_search_notes_titles from App_Function_Libraries.DB.Character_Chat_DB import ( add_character_card as sqlite_add_character_card, get_character_cards as sqlite_get_character_cards, \ get_character_card_by_id as sqlite_get_character_card_by_id, update_character_card as sqlite_update_character_card, \ delete_character_card as 
     get_character_chats as sqlite_get_character_chats, get_character_chat_by_id as sqlite_get_character_chat_by_id, \
-    update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat, \
-    migrate_chat_to_media_db as sqlite_migrate_chat_to_media_db,
+    update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat
 )
 #
 # Local Imports
@@ -214,9 +215,9 @@ def ensure_directory_exists(file_path):
 #
 # DB Search functions
 
-def search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10):
+def search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10):
     if db_type == 'sqlite':
-        return sqlite_search_db(search_query, search_fields, keywords, page, results_per_page)
+        return sqlite_search_media_db(search_query, search_fields, keywords, page, results_per_page)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version when available
         raise NotImplementedError("Elasticsearch version of search_db not yet implemented")
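Every wrapper in DB_Manager.py follows the same backend-dispatch shape; a condensed, self-contained sketch of that pattern, with `db_type` assumed to be loaded from the application config elsewhere in the module:

    db_type = 'sqlite'  # assumption: read from config in the real module

    def make_dispatcher(sqlite_impl, name):
        # Route calls to the SQLite implementation; other backends are not implemented yet.
        def wrapper(*args, **kwargs):
            if db_type == 'sqlite':
                return sqlite_impl(*args, **kwargs)
            elif db_type == 'elasticsearch':
                raise NotImplementedError(f"Elasticsearch version of {name} not yet implemented")
            raise ValueError(f"Unsupported database type: {db_type}")
        return wrapper
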
@@ -500,13 +501,6 @@ def load_prompt_details(*args, **kwargs):
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def load_preset_prompts(*args, **kwargs):
-    if db_type == 'sqlite':
-        return sqlite_load_preset_prompts()
-    elif db_type == 'elasticsearch':
-        # Implement Elasticsearch version
-        raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
-
 def insert_prompt_to_db(*args, **kwargs):
     if db_type == 'sqlite':
         return sqlite_insert_prompt_to_db(*args, **kwargs)
@@ -539,7 +533,6 @@ def mark_as_trash(media_id: int) -> None:
     else:
         raise ValueError(f"Unsupported database type: {db_type}")
 
-
 def get_latest_transcription(*args, **kwargs):
     if db_type == 'sqlite':
         return sqlite_get_latest_transcription(*args, **kwargs)
@@ -721,62 +714,132 @@ def fetch_keywords_for_media(*args, **kwargs):
 #
 # Chat-related Functions
 
-def delete_chat_message(*args, **kwargs):
+def search_notes_titles(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_search_notes_titles(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of search_notes_titles not yet implemented")
+
+def save_message(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_save_message(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of save_message not yet implemented")
+
+def load_chat_history(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_load_chat_history(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of load_chat_history not yet implemented")
+
+def start_new_conversation(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_start_new_conversation(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of start_new_conversation not yet implemented")
+
+def get_all_conversations(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_get_all_conversations(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of get_all_conversations not yet implemented")
+
+def get_notes_by_keywords(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_get_notes_by_keywords(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of get_notes_by_keywords not yet implemented")
+
+def get_note_by_id(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_get_note_by_id(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of get_note_by_id not yet implemented")
+
+def add_keywords_to_conversation(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_add_keywords_to_conversation(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of add_keywords_to_conversation not yet implemented")
+
+def get_keywords_for_note(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_delete_chat_message(*args, **kwargs)
+        return sqlite_get_keywords_for_note(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def update_chat_message(*args, **kwargs):
+def delete_note(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_update_chat_message(*args, **kwargs)
+        return sqlite_delete_note(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def add_chat_message(*args, **kwargs):
+def search_conversations_by_keywords(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_add_chat_message(*args, **kwargs)
+        return sqlite_search_conversations_by_keywords(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def get_chat_messages(*args, **kwargs):
+def delete_conversation(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_get_chat_messages(*args, **kwargs)
+        return sqlite_delete_conversation(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def search_chat_conversations(*args, **kwargs):
+def get_conversation_title(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_search_chat_conversations(*args, **kwargs)
+        return sqlite_get_conversation_title(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def create_chat_conversation(*args, **kwargs):
+def update_conversation_title(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_create_chat_conversation(*args, **kwargs)
+        return sqlite_update_conversation_title(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def save_chat_history_to_database(*args, **kwargs):
+def fetch_all_conversations(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_save_chat_history_to_database(*args, **kwargs)
+        return sqlite_fetch_all_conversations(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
-def get_conversation_name(*args, **kwargs):
+def fetch_all_notes(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_get_conversation_name(*args, **kwargs)
+        return sqlite_fetch_all_notes(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
         raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
 
+def delete_messages_in_conversation(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_delete_messages_in_conversation(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of delete_messages_in_conversation not yet implemented")
+
+def get_conversation_text(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_get_conversation_text(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of get_conversation_text not yet implemented")
+
 #
 # End of Chat-related Functions
 ############################################################################################################
@@ -856,12 +919,54 @@ def delete_character_chat(*args, **kwargs):
     # Implement Elasticsearch version
     raise NotImplementedError("Elasticsearch version of delete_character_chat not yet implemented")
 
-def migrate_chat_to_media_db(*args, **kwargs):
+def update_note(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_update_note(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of update_note not yet implemented")
+
+def save_notes(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_save_notes(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of save_notes not yet implemented")
+
+def clear_keywords(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_clear_keywords_from_note(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of clear_keywords not yet implemented")
+
+def clear_keywords_from_note(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_clear_keywords_from_note(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of clear_keywords_from_note not yet implemented")
+
+def add_keywords_to_note(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_add_keywords_to_note(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of add_keywords_to_note not yet implemented")
+
+def fetch_conversations_by_ids(*args, **kwargs):
+    if db_type == 'sqlite':
+        return sqlite_fetch_conversations_by_ids(*args, **kwargs)
+    elif db_type == 'elasticsearch':
+        # Implement Elasticsearch version
+        raise NotImplementedError("Elasticsearch version of fetch_conversations_by_ids not yet implemented")
+
+def fetch_notes_by_ids(*args, **kwargs):
     if db_type == 'sqlite':
-        return sqlite_migrate_chat_to_media_db(*args, **kwargs)
+        return sqlite_fetch_notes_by_ids(*args, **kwargs)
     elif db_type == 'elasticsearch':
         # Implement Elasticsearch version
-        raise NotImplementedError("Elasticsearch version of migrate_chat_to_media_db not yet implemented")
+        raise NotImplementedError("Elasticsearch version of fetch_notes_by_ids not yet implemented")
 
 #
 # End of Character Chat-related Functions
diff --git a/App_Function_Libraries/DB/Prompts_DB.py b/App_Function_Libraries/DB/Prompts_DB.py
new file mode 100644
index 000000000..68d8eae55
--- /dev/null
+++ b/App_Function_Libraries/DB/Prompts_DB.py
@@ -0,0 +1,626 @@
+# Prompts_DB.py
+# Description: Functions to manage the prompts database.
+#
+# Imports
+import sqlite3
+import logging
+#
+# External Imports
+import re
+from typing import Tuple
+#
+# Local Imports
+from App_Function_Libraries.Utils.Utils import get_database_path
+#
+#######################################################################################################################
+#
+# Functions to manage prompts DB
+
+def create_prompts_db():
+    logging.debug("create_prompts_db: Creating prompts database.")
+    with sqlite3.connect(get_database_path('prompts.db')) as conn:
+        cursor = conn.cursor()
+        cursor.executescript('''
+            CREATE TABLE IF NOT EXISTS Prompts (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                name TEXT NOT NULL UNIQUE,
+                author TEXT,
+                details TEXT,
+                system TEXT,
+                user TEXT
+            );
+            CREATE TABLE IF NOT EXISTS Keywords (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                keyword TEXT NOT NULL UNIQUE COLLATE NOCASE
+            );
+            CREATE TABLE IF NOT EXISTS PromptKeywords (
+                prompt_id INTEGER,
+                keyword_id INTEGER,
+                FOREIGN KEY (prompt_id) REFERENCES Prompts (id),
+                FOREIGN KEY (keyword_id) REFERENCES Keywords (id),
+                PRIMARY KEY (prompt_id, keyword_id)
+            );
+            CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON Keywords(keyword);
+            CREATE INDEX IF NOT EXISTS idx_promptkeywords_prompt_id ON PromptKeywords(prompt_id);
+            CREATE INDEX IF NOT EXISTS idx_promptkeywords_keyword_id ON PromptKeywords(keyword_id);
+        ''')
+
+# FIXME - dirty hack that should be removed later...
+# Migration function to add the 'author' column to the Prompts table
+def add_author_column_to_prompts():
+    with sqlite3.connect(get_database_path('prompts.db')) as conn:
+        cursor = conn.cursor()
+        # Check if 'author' column already exists
+        cursor.execute("PRAGMA table_info(Prompts)")
+        columns = [col[1] for col in cursor.fetchall()]
+
+        if not columns:
+            # Prompts table doesn't exist yet; create_prompts_db() at the bottom of this
+            # module will create it with the 'author' column already in place.
+            return
+        if 'author' not in columns:
+            # Add the 'author' column
+            cursor.execute('ALTER TABLE Prompts ADD COLUMN author TEXT')
+            print("Author column added to Prompts table.")
+        else:
+            print("Author column already exists in Prompts table.")
+
+add_author_column_to_prompts()
+
+def normalize_keyword(keyword):
+    return re.sub(r'\s+', ' ', keyword.strip().lower())
+
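A quick illustration of what `normalize_keyword` produces (runs of whitespace collapse to single spaces, case folds to lower):

    assert normalize_keyword("  Machine   Learning ") == "machine learning"
    assert normalize_keyword("RAG\tSearch") == "rag search"
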
+
+# FIXME - update calls to this function to use the new args
+def add_prompt(name, author, details, system=None, user=None, keywords=None):
+    logging.debug(f"add_prompt: Adding prompt with name: {name}, author: {author}, system: {system}, user: {user}, keywords: {keywords}")
+    if not name:
+        logging.error("add_prompt: A name is required.")
+        return "A name is required."
+
+    try:
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+            cursor.execute('''
+                INSERT INTO Prompts (name, author, details, system, user)
+                VALUES (?, ?, ?, ?, ?)
+            ''', (name, author, details, system, user))
+            prompt_id = cursor.lastrowid
+
+            if keywords:
+                normalized_keywords = [normalize_keyword(k) for k in keywords if k.strip()]
+                for keyword in set(normalized_keywords):  # Use set to remove duplicates
+                    cursor.execute('''
+                        INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)
+                    ''', (keyword,))
+                    cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,))
+                    keyword_id = cursor.fetchone()[0]
+                    cursor.execute('''
+                        INSERT OR IGNORE INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)
+                    ''', (prompt_id, keyword_id))
+        return "Prompt added successfully."
+    except sqlite3.IntegrityError:
+        return "Prompt with this name already exists."
+    except sqlite3.Error as e:
+        return f"Database error: {e}"
+
+
+def fetch_prompt_details(name):
+    logging.debug(f"fetch_prompt_details: Fetching details for prompt: {name}")
+    with sqlite3.connect(get_database_path('prompts.db')) as conn:
+        cursor = conn.cursor()
+        cursor.execute('''
+            SELECT p.name, p.author, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords
+            FROM Prompts p
+            LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id
+            LEFT JOIN Keywords k ON pk.keyword_id = k.id
+            WHERE p.name = ?
+            GROUP BY p.id
+        ''', (name,))
+        return cursor.fetchone()
+
+
+def list_prompts(page=1, per_page=10):
+    logging.debug(f"list_prompts: Listing prompts for page {page} with {per_page} prompts per page.")
+    offset = (page - 1) * per_page
+    with sqlite3.connect(get_database_path('prompts.db')) as conn:
+        cursor = conn.cursor()
+        cursor.execute('SELECT name FROM Prompts LIMIT ? OFFSET ?', (per_page, offset))
+        prompts = [row[0] for row in cursor.fetchall()]
+
+        # Get total count of prompts
+        cursor.execute('SELECT COUNT(*) FROM Prompts')
+        total_count = cursor.fetchone()[0]
+
+    total_pages = (total_count + per_page - 1) // per_page
+    return prompts, total_pages, page
+
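A short usage sketch for the pagination contract of `list_prompts`; the ceiling division above yields the page count:

    prompts, total_pages, current_page = list_prompts(page=1, per_page=10)
    print(f"Page {current_page}/{total_pages}: {prompts}")
    # e.g. 25 prompts at 10 per page -> (25 + 10 - 1) // 10 == 3 pages
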
+
+def insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords=None):
+    return add_prompt(title, author, description, system_prompt, user_prompt, keywords)
+
+
+def get_prompt_db_connection():
+    prompt_db_path = get_database_path('prompts.db')
+    return sqlite3.connect(prompt_db_path)
+
+
+def search_prompts(query):
+    logging.debug(f"search_prompts: Searching prompts with query: {query}")
+    try:
+        with get_prompt_db_connection() as conn:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT p.name, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords
+                FROM Prompts p
+                LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id
+                LEFT JOIN Keywords k ON pk.keyword_id = k.id
+                WHERE p.name LIKE ? OR p.details LIKE ? OR p.system LIKE ? OR p.user LIKE ? OR k.keyword LIKE ?
+                GROUP BY p.id
+                ORDER BY p.name
+            """, (f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%'))
+            return cursor.fetchall()
+    except sqlite3.Error as e:
+        logging.error(f"Error searching prompts: {e}")
+        return []
+
+
+def search_prompts_by_keyword(keyword, page=1, per_page=10):
+    logging.debug(f"search_prompts_by_keyword: Searching prompts by keyword: {keyword}")
+    normalized_keyword = normalize_keyword(keyword)
+    offset = (page - 1) * per_page
+    with sqlite3.connect(get_database_path('prompts.db')) as conn:
+        cursor = conn.cursor()
+        cursor.execute('''
+            SELECT DISTINCT p.name
+            FROM Prompts p
+            JOIN PromptKeywords pk ON p.id = pk.prompt_id
+            JOIN Keywords k ON pk.keyword_id = k.id
+            WHERE k.keyword LIKE ?
+            LIMIT ? OFFSET ?
+        ''', ('%' + normalized_keyword + '%', per_page, offset))
+        prompts = [row[0] for row in cursor.fetchall()]
+
+        # Get total count of matching prompts
+        cursor.execute('''
+            SELECT COUNT(DISTINCT p.id)
+            FROM Prompts p
+            JOIN PromptKeywords pk ON p.id = pk.prompt_id
+            JOIN Keywords k ON pk.keyword_id = k.id
+            WHERE k.keyword LIKE ?
+        ''', ('%' + normalized_keyword + '%',))
+        total_count = cursor.fetchone()[0]
+
+    total_pages = (total_count + per_page - 1) // per_page
+    return prompts, total_pages, page
+
+
+def update_prompt_keywords(prompt_name, new_keywords):
+    logging.debug(f"update_prompt_keywords: Updating keywords for prompt: {prompt_name}")
+    try:
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+
+            cursor.execute('SELECT id FROM Prompts WHERE name = ?', (prompt_name,))
+            prompt_id = cursor.fetchone()
+            if not prompt_id:
+                return "Prompt not found."
+            prompt_id = prompt_id[0]
+
+            cursor.execute('DELETE FROM PromptKeywords WHERE prompt_id = ?', (prompt_id,))
+
+            normalized_keywords = [normalize_keyword(k) for k in new_keywords if k.strip()]
+            for keyword in set(normalized_keywords):  # Use set to remove duplicates
+                cursor.execute('INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)', (keyword,))
+                cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,))
+                keyword_id = cursor.fetchone()[0]
+                cursor.execute('INSERT INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)',
+                               (prompt_id, keyword_id))
+
+            # Remove unused keywords
+            cursor.execute('''
+                DELETE FROM Keywords
+                WHERE id NOT IN (SELECT DISTINCT keyword_id FROM PromptKeywords)
+            ''')
+        return "Keywords updated successfully."
+    except sqlite3.Error as e:
+        return f"Database error: {e}"
+
+
+def add_or_update_prompt(title, author, description, system_prompt, user_prompt, keywords=None):
+    logging.debug(f"add_or_update_prompt: Adding or updating prompt: {title}")
+    if not title:
+        return "Error: Title is required."
+
+    existing_prompt = fetch_prompt_details(title)
+    if existing_prompt:
+        # Update existing prompt
+        result = update_prompt_in_db(title, author, description, system_prompt, user_prompt)
+        if "successfully" in result:
+            # Update keywords if the prompt update was successful
+            keyword_result = update_prompt_keywords(title, keywords or [])
+            result += f" {keyword_result}"
+    else:
+        # Insert new prompt
+        result = insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords)
+
+    return result
+
+
+def load_prompt_details(selected_prompt):
+    logging.debug(f"load_prompt_details: Loading prompt details for {selected_prompt}")
+    if selected_prompt:
+        details = fetch_prompt_details(selected_prompt)
+        if details:
+            return details[0], details[1], details[2], details[3], details[4], details[5]
+    return "", "", "", "", "", ""
+
+
+def update_prompt_in_db(title, author, description, system_prompt, user_prompt):
+    logging.debug(f"update_prompt_in_db: Updating prompt: {title}")
+    try:
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+            cursor.execute(
+                "UPDATE Prompts SET author = ?, details = ?, system = ?, user = ? WHERE name = ?",
+                (author, description, system_prompt, user_prompt, title)
+            )
+            if cursor.rowcount == 0:
+                return "No prompt found with the given title."
+        return "Prompt updated successfully!"
+    except sqlite3.Error as e:
+        return f"Error updating prompt: {e}"
+
+
+def delete_prompt(prompt_id):
+    logging.debug(f"delete_prompt: Deleting prompt with ID: {prompt_id}")
+    try:
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+
+            # Delete associated keywords
+            cursor.execute("DELETE FROM PromptKeywords WHERE prompt_id = ?", (prompt_id,))
+
+            # Delete the prompt
+            cursor.execute("DELETE FROM Prompts WHERE id = ?", (prompt_id,))
+
+            if cursor.rowcount == 0:
+                return f"No prompt found with ID {prompt_id}"
+            else:
+                conn.commit()
+                return f"Prompt with ID {prompt_id} has been successfully deleted"
+    except sqlite3.Error as e:
+        return f"An error occurred: {e}"
+
+
+def delete_prompt_keyword(keyword: str) -> str:
+    """
+    Delete a keyword and its associations from the prompts database.
+
+    Args:
+        keyword (str): The keyword to delete
+
+    Returns:
+        str: Success/failure message
+    """
+    logging.debug(f"delete_prompt_keyword: Deleting keyword: {keyword}")
+    try:
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+
+            # First normalize the keyword
+            normalized_keyword = normalize_keyword(keyword)
+
+            # Get the keyword ID
+            cursor.execute("SELECT id FROM Keywords WHERE keyword = ?", (normalized_keyword,))
+            result = cursor.fetchone()
+
+            if not result:
+                return f"Keyword '{keyword}' not found."
+
+            keyword_id = result[0]
+
+            # Delete keyword associations from PromptKeywords; read rowcount here,
+            # before the next statement overwrites it, so the count reflects the
+            # number of prompt associations actually removed
+            cursor.execute("DELETE FROM PromptKeywords WHERE keyword_id = ?", (keyword_id,))
+            affected_prompts = cursor.rowcount
+
+            # Delete the keyword itself
+            cursor.execute("DELETE FROM Keywords WHERE id = ?", (keyword_id,))
+
+            conn.commit()
+
+            logging.info(f"Keyword '{keyword}' deleted successfully")
+            return f"Successfully deleted keyword '{keyword}' and removed it from {affected_prompts} prompts."
+
+    except sqlite3.Error as e:
+        error_msg = f"Database error deleting keyword: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+    except Exception as e:
+        error_msg = f"Error deleting keyword: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+
+
+def export_prompt_keywords_to_csv() -> Tuple[str, str]:
+    """
+    Export all prompt keywords to a CSV file with associated metadata.
+
+    Returns:
+        Tuple[str, str]: (status_message, file_path)
+    """
+    import csv
+    import tempfile
+    import os
+    from datetime import datetime
+
+    logging.debug("export_prompt_keywords_to_csv: Starting export")
+    try:
+        # Create a temporary file with a specific name in the system's temp directory
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        temp_dir = tempfile.gettempdir()
+        file_path = os.path.join(temp_dir, f'prompt_keywords_export_{timestamp}.csv')
+
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+
+            # Get keywords with related prompt information
+            query = '''
+                SELECT
+                    k.keyword,
+                    GROUP_CONCAT(p.name, ' | ') as prompt_names,
+                    COUNT(DISTINCT p.id) as num_prompts,
+                    GROUP_CONCAT(DISTINCT p.author, ' | ') as authors
+                FROM Keywords k
+                LEFT JOIN PromptKeywords pk ON k.id = pk.keyword_id
+                LEFT JOIN Prompts p ON pk.prompt_id = p.id
+                GROUP BY k.id, k.keyword
+                ORDER BY k.keyword
+            '''
+
+            cursor.execute(query)
+            results = cursor.fetchall()
+
+            # Write to CSV
+            with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
+                writer = csv.writer(csvfile)
+                writer.writerow([
+                    'Keyword',
+                    'Associated Prompts',
+                    'Number of Prompts',
+                    'Authors'
+                ])
+
+                for row in results:
+                    writer.writerow([
+                        row[0],  # keyword
+                        row[1] if row[1] else '',  # prompt_names (may be None)
+                        row[2],  # num_prompts
+                        row[3] if row[3] else ''  # authors (may be None)
+                    ])
+
+        status_msg = f"Successfully exported {len(results)} prompt keywords to CSV."
+        logging.info(status_msg)
+
+        return status_msg, file_path
+
+    except sqlite3.Error as e:
+        error_msg = f"Database error exporting keywords: {str(e)}"
+        logging.error(error_msg)
+        return error_msg, "None"
+    except Exception as e:
+        error_msg = f"Error exporting keywords: {str(e)}"
+        logging.error(error_msg)
+        return error_msg, "None"
+
+
+def view_prompt_keywords() -> str:
+    """
+    View all keywords currently in the prompts database.
+
+    Returns:
+        str: Markdown formatted string of all keywords
+    """
+    logging.debug("view_prompt_keywords: Retrieving all keywords")
+    try:
+        with sqlite3.connect(get_database_path('prompts.db')) as conn:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT k.keyword, COUNT(DISTINCT pk.prompt_id) as prompt_count
+                FROM Keywords k
+                LEFT JOIN PromptKeywords pk ON k.id = pk.keyword_id
+                GROUP BY k.id, k.keyword
+                ORDER BY k.keyword
+            """)
+
+            keywords = cursor.fetchall()
+            if keywords:
+                keyword_list = [f"- {k[0]} ({k[1]} prompts)" for k in keywords]
+                return "### Current Prompt Keywords:\n" + "\n".join(keyword_list)
+            return "No keywords found."
+
+    except Exception as e:
+        error_msg = f"Error retrieving keywords: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+
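An illustrative call showing the (status_message, file_path) convention used by the export helpers in this file; note that they return the string "None" as the path on failure:

    status, csv_path = export_prompt_keywords_to_csv()
    print(status)
    if csv_path != "None":
        with open(csv_path, newline='', encoding='utf-8') as f:
            print(f.read())
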
+
+def export_prompts(
+        export_format='csv',
+        filter_keywords=None,
+        include_system=True,
+        include_user=True,
+        include_details=True,
+        include_author=True,
+        include_keywords=True,
+        markdown_template=None
+) -> Tuple[str, str]:
+    """
+    Export prompts to CSV or Markdown with configurable options.
+
+    Args:
+        export_format (str): 'csv' or 'markdown'
+        filter_keywords (List[str], optional): Keywords to filter prompts by
+        include_system (bool): Include system prompts in export
+        include_user (bool): Include user prompts in export
+        include_details (bool): Include prompt details/descriptions
+        include_author (bool): Include author information
+        include_keywords (bool): Include associated keywords
+        markdown_template (str, optional): Template for markdown export
+
+    Returns:
+        Tuple[str, str]: (status_message, file_path)
+    """
+    import csv
+    import tempfile
+    import os
+    import zipfile
+    from datetime import datetime
+
+    try:
+        # Get prompts data
+        with get_prompt_db_connection() as conn:
+            cursor = conn.cursor()
+
+            # Build query based on included fields
+            select_fields = ['p.name']
+            if include_author:
+                select_fields.append('p.author')
+            if include_details:
+                select_fields.append('p.details')
+            if include_system:
+                select_fields.append('p.system')
+            if include_user:
+                select_fields.append('p.user')
+
+            query = f"""
+                SELECT DISTINCT {', '.join(select_fields)}
+                FROM Prompts p
+            """
+
+            # Add keyword filtering if specified
+            if filter_keywords:
+                placeholders = ','.join(['?' for _ in filter_keywords])
+                query += f"""
+                    JOIN PromptKeywords pk ON p.id = pk.prompt_id
+                    JOIN Keywords k ON pk.keyword_id = k.id
+                    WHERE k.keyword IN ({placeholders})
+                """
+
+            cursor.execute(query, filter_keywords if filter_keywords else ())
+            prompts = cursor.fetchall()
+
+            # Get keywords for each prompt if needed
+            if include_keywords:
+                prompt_keywords = {}
+                for prompt in prompts:
+                    cursor.execute("""
+                        SELECT k.keyword
+                        FROM Keywords k
+                        JOIN PromptKeywords pk ON k.id = pk.keyword_id
+                        JOIN Prompts p ON pk.prompt_id = p.id
+                        WHERE p.name = ?
+                    """, (prompt[0],))
+                    prompt_keywords[prompt[0]] = [row[0] for row in cursor.fetchall()]
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+        if export_format == 'csv':
+            # Export as CSV
+            temp_file = os.path.join(tempfile.gettempdir(), f'prompts_export_{timestamp}.csv')
+            with open(temp_file, 'w', newline='', encoding='utf-8') as csvfile:
+                writer = csv.writer(csvfile)
+
+                # Write header
+                header = ['Name']
+                if include_author:
+                    header.append('Author')
+                if include_details:
+                    header.append('Details')
+                if include_system:
+                    header.append('System Prompt')
+                if include_user:
+                    header.append('User Prompt')
+                if include_keywords:
+                    header.append('Keywords')
+                writer.writerow(header)
+
+                # Write data
+                for prompt in prompts:
+                    row = list(prompt)
+                    if include_keywords:
+                        row.append(', '.join(prompt_keywords.get(prompt[0], [])))
+                    writer.writerow(row)
+
+            return f"Successfully exported {len(prompts)} prompts to CSV.", temp_file
+
+        else:
+            # Export as Markdown files in ZIP
+            temp_dir = tempfile.mkdtemp()
+            zip_path = os.path.join(tempfile.gettempdir(), f'prompts_export_{timestamp}.zip')
+
+            # Define markdown templates
+            templates = {
+                "Basic Template": """# {title}
+{author_section}
+{details_section}
+{system_section}
+{user_section}
+{keywords_section}
+""",
+                "Detailed Template": """# {title}
+
+## Author
+{author_section}
+
+## Description
+{details_section}
+
+## System Prompt
+{system_section}
+
+## User Prompt
+{user_section}
+
+## Keywords
+{keywords_section}
+"""
+            }
+
+            template = templates.get(markdown_template, markdown_template or templates["Basic Template"])
+
+            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+                for prompt in prompts:
+                    # Create markdown content
+                    md_content = template.format(
+                        title=prompt[0],
author_section=f"Author: {prompt[1]}" if include_author else "", + details_section=prompt[2] if include_details else "", + system_section=prompt[3] if include_system else "", + user_section=prompt[4] if include_user else "", + keywords_section=', '.join(prompt_keywords.get(prompt[0], [])) if include_keywords else "" + ) + + # Create safe filename + safe_filename = re.sub(r'[^\w\-_\. ]', '_', prompt[0]) + md_path = os.path.join(temp_dir, f"{safe_filename}.md") + + # Write markdown file + with open(md_path, 'w', encoding='utf-8') as f: + f.write(md_content) + + # Add to ZIP + zipf.write(md_path, os.path.basename(md_path)) + + return f"Successfully exported {len(prompts)} prompts to Markdown files.", zip_path + + except Exception as e: + error_msg = f"Error exporting prompts: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + + +create_prompts_db() + +# +# End of Propmts_DB.py +####################################################################################################################### + diff --git a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py index f4df79894..35db2b3ba 100644 --- a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py +++ b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py @@ -4,39 +4,37 @@ # Imports import configparser import logging +import os import re import sqlite3 import uuid from contextlib import contextmanager from datetime import datetime - -from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path - +from pathlib import Path +from typing import List, Dict, Any, Tuple, Optional # # External Imports # (No external imports) # # Local Imports -# (No additional local imports) +from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_project_root + # ######################################################################################################################## # # Functions: -# Construct the path to the config file -config_path = get_project_relative_path('Config_Files/config.txt') - -# Read the config file -config = configparser.ConfigParser() -config.read(config_path) - -# Get the SQLite path from the config, or use the default if not specified -if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'): - rag_qa_db_path = config.get('Database', 'rag_qa_db_path') -else: - rag_qa_db_path = get_database_path('RAG_QA_Chat.db') - -print(f"RAG QA Chat Database path: {rag_qa_db_path}") +def get_rag_qa_db_path(): + config_path = os.path.join(get_project_root(), 'Config_Files', 'config.txt') + config = configparser.ConfigParser() + config.read(config_path) + if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'): + rag_qa_db_path = config.get('Database', 'rag_qa_db_path') + if not os.path.isabs(rag_qa_db_path): + rag_qa_db_path = get_project_relative_path(rag_qa_db_path) + return rag_qa_db_path + else: + raise ValueError("Database path not found in config file") # Set up logging logging.basicConfig(level=logging.INFO) @@ -58,7 +56,9 @@ conversation_id TEXT PRIMARY KEY, created_at DATETIME NOT NULL, last_updated DATETIME NOT NULL, - title TEXT NOT NULL + title TEXT NOT NULL, + media_id INTEGER, + rating INTEGER CHECK(rating BETWEEN 1 AND 3) ); -- Table for storing keywords @@ -122,19 +122,137 @@ CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_collection_id ON rag_qa_collection_keywords(collection_id); CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_keyword_id ON 
diff --git a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py
index f4df79894..35db2b3ba 100644
--- a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py
+++ b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py
@@ -4,39 +4,37 @@
 # Imports
 import configparser
 import logging
+import os
 import re
 import sqlite3
 import uuid
 from contextlib import contextmanager
 from datetime import datetime
-
-from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path
-
+from pathlib import Path
+from typing import List, Dict, Any, Tuple, Optional
 #
 # External Imports
 # (No external imports)
 #
 # Local Imports
-# (No additional local imports)
+from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_project_root
+
 #
 ########################################################################################################################
 #
 # Functions:
 
-# Construct the path to the config file
-config_path = get_project_relative_path('Config_Files/config.txt')
-
-# Read the config file
-config = configparser.ConfigParser()
-config.read(config_path)
-
-# Get the SQLite path from the config, or use the default if not specified
-if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'):
-    rag_qa_db_path = config.get('Database', 'rag_qa_db_path')
-else:
-    rag_qa_db_path = get_database_path('RAG_QA_Chat.db')
-
-print(f"RAG QA Chat Database path: {rag_qa_db_path}")
+def get_rag_qa_db_path():
+    config_path = os.path.join(get_project_root(), 'Config_Files', 'config.txt')
+    config = configparser.ConfigParser()
+    config.read(config_path)
+    if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'):
+        rag_qa_db_path = config.get('Database', 'rag_qa_db_path')
+        if not os.path.isabs(rag_qa_db_path):
+            rag_qa_db_path = get_project_relative_path(rag_qa_db_path)
+        return rag_qa_db_path
+    else:
+        raise ValueError("Database path not found in config file")
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -58,7 +56,9 @@
     conversation_id TEXT PRIMARY KEY,
     created_at DATETIME NOT NULL,
     last_updated DATETIME NOT NULL,
-    title TEXT NOT NULL
+    title TEXT NOT NULL,
+    media_id INTEGER,
+    rating INTEGER CHECK(rating BETWEEN 1 AND 3)
 );
 
 -- Table for storing keywords
@@ -122,19 +122,137 @@
 CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_collection_id ON rag_qa_collection_keywords(collection_id);
 CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_keyword_id ON rag_qa_collection_keywords(keyword_id);
 
--- Full-text search virtual table for chat content
-CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5(conversation_id, timestamp, role, content);
+-- Full-text search virtual tables
+CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5(
+    content,
+    content='rag_qa_chats',
+    content_rowid='id'
+);
+
+-- FTS table for conversation metadata
+CREATE VIRTUAL TABLE IF NOT EXISTS conversation_metadata_fts USING fts5(
+    title,
+    content='conversation_metadata',
+    content_rowid='rowid'
+);
+
+-- FTS table for keywords
+CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_keywords_fts USING fts5(
+    keyword,
+    content='rag_qa_keywords',
+    content_rowid='id'
+);
+
+-- FTS table for keyword collections
+CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_keyword_collections_fts USING fts5(
+    name,
+    content='rag_qa_keyword_collections',
+    content_rowid='id'
+);
+
+-- FTS table for notes (indexes both title and content)
+CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_notes_fts USING fts5(
+    title,
+    content,
+    content='rag_qa_notes',
+    content_rowid='id'
+);
 
--- Trigger to keep the FTS table up to date
+-- Triggers for maintaining FTS indexes
+-- Triggers for rag_qa_chats
 CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ai AFTER INSERT ON rag_qa_chats BEGIN
-    INSERT INTO rag_qa_chats_fts(conversation_id, timestamp, role, content) VALUES (new.conversation_id, new.timestamp, new.role, new.content);
+    INSERT INTO rag_qa_chats_fts(rowid, content)
+    VALUES (new.id, new.content);
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_chats_au AFTER UPDATE ON rag_qa_chats BEGIN
+    UPDATE rag_qa_chats_fts
+    SET content = new.content
+    WHERE rowid = old.id;
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ad AFTER DELETE ON rag_qa_chats BEGIN
+    DELETE FROM rag_qa_chats_fts WHERE rowid = old.id;
+END;
+
+-- Triggers for conversation_metadata
+CREATE TRIGGER IF NOT EXISTS conversation_metadata_ai AFTER INSERT ON conversation_metadata BEGIN
+    INSERT INTO conversation_metadata_fts(rowid, title)
+    VALUES (new.rowid, new.title);
+END;
+
+CREATE TRIGGER IF NOT EXISTS conversation_metadata_au AFTER UPDATE ON conversation_metadata BEGIN
+    UPDATE conversation_metadata_fts
+    SET title = new.title
+    WHERE rowid = old.rowid;
+END;
+
+CREATE TRIGGER IF NOT EXISTS conversation_metadata_ad AFTER DELETE ON conversation_metadata BEGIN
+    DELETE FROM conversation_metadata_fts WHERE rowid = old.rowid;
+END;
+
+-- Triggers for rag_qa_keywords
+CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_ai AFTER INSERT ON rag_qa_keywords BEGIN
+    INSERT INTO rag_qa_keywords_fts(rowid, keyword)
+    VALUES (new.id, new.keyword);
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_au AFTER UPDATE ON rag_qa_keywords BEGIN
+    UPDATE rag_qa_keywords_fts
+    SET keyword = new.keyword
+    WHERE rowid = old.id;
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_ad AFTER DELETE ON rag_qa_keywords BEGIN
+    DELETE FROM rag_qa_keywords_fts WHERE rowid = old.id;
+END;
+
+-- Triggers for rag_qa_keyword_collections
+CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_ai AFTER INSERT ON rag_qa_keyword_collections BEGIN
+    INSERT INTO rag_qa_keyword_collections_fts(rowid, name)
+    VALUES (new.id, new.name);
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_au AFTER UPDATE ON rag_qa_keyword_collections BEGIN
+    UPDATE rag_qa_keyword_collections_fts
+    SET name = new.name
+    WHERE rowid = old.id;
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_ad AFTER DELETE ON rag_qa_keyword_collections BEGIN
+    DELETE FROM rag_qa_keyword_collections_fts WHERE rowid = old.id;
+END;
+
+-- Triggers for rag_qa_notes
+CREATE TRIGGER IF NOT EXISTS rag_qa_notes_ai AFTER INSERT ON rag_qa_notes BEGIN
+    INSERT INTO rag_qa_notes_fts(rowid, title, content)
+    VALUES (new.id, new.title, new.content);
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_notes_au AFTER UPDATE ON rag_qa_notes BEGIN
+    UPDATE rag_qa_notes_fts
+    SET title = new.title,
+        content = new.content
+    WHERE rowid = old.id;
+END;
+
+CREATE TRIGGER IF NOT EXISTS rag_qa_notes_ad AFTER DELETE ON rag_qa_notes BEGIN
+    DELETE FROM rag_qa_notes_fts WHERE rowid = old.id;
 END;
 '''
 
 # Database connection management
 @contextmanager
 def get_db_connection():
-    conn = sqlite3.connect(rag_qa_db_path)
+    db_path = get_rag_qa_db_path()
+    conn = sqlite3.connect(db_path)
     try:
         yield conn
     finally:
@@ -168,10 +286,43 @@ def execute_query(query, params=None, conn=None):
             conn.commit()
             return cursor.fetchall()
 
+
 def create_tables():
+    """Create database tables and initialize FTS indexes."""
     with get_db_connection() as conn:
-        conn.executescript(SCHEMA_SQL)
-        logger.info("All RAG QA Chat tables created successfully")
+        cursor = conn.cursor()
+        # Execute the SCHEMA_SQL to create tables and triggers
+        cursor.executescript(SCHEMA_SQL)
+
+        # Check and populate all FTS tables
+        fts_tables = [
+            ('rag_qa_notes_fts', 'rag_qa_notes', ['title', 'content']),
+            ('rag_qa_chats_fts', 'rag_qa_chats', ['content']),
+            ('conversation_metadata_fts', 'conversation_metadata', ['title']),
+            ('rag_qa_keywords_fts', 'rag_qa_keywords', ['keyword']),
+            ('rag_qa_keyword_collections_fts', 'rag_qa_keyword_collections', ['name'])
+        ]
+
+        for fts_table, source_table, columns in fts_tables:
+            # Check if FTS table needs population
+            cursor.execute(f"SELECT COUNT(*) FROM {fts_table}")
+            fts_count = cursor.fetchone()[0]
+            cursor.execute(f"SELECT COUNT(*) FROM {source_table}")
+            source_count = cursor.fetchone()[0]
+
+            if fts_count != source_count:
+                # Repopulate FTS table
+                logger.info(f"Repopulating {fts_table}")
+                cursor.execute(f"DELETE FROM {fts_table}")
+                columns_str = ', '.join(columns)
+                source_columns = ', '.join(["id" if source_table != 'conversation_metadata' else "rowid"] + columns)
+                cursor.execute(f"""
+                    INSERT INTO {fts_table}(rowid, {columns_str})
+                    SELECT {source_columns} FROM {source_table}
+                """)
+
+        logger.info("All RAG QA Chat tables and triggers created successfully")
+
 
 # Initialize the database
 create_tables()
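The FTS5 tables in this schema are external-content tables (the `content=...` option), so they store only the index and rely on the triggers above to stay in sync; queries join back to the source table through `rowid`. A minimal sketch, assuming the database path resolves via `get_rag_qa_db_path()`:

    import sqlite3

    with sqlite3.connect(get_rag_qa_db_path()) as conn:
        rows = conn.execute(
            """
            SELECT n.id, n.title
            FROM rag_qa_notes n
            JOIN rag_qa_notes_fts fts ON n.id = fts.rowid
            WHERE rag_qa_notes_fts MATCH ?
            ORDER BY rank
            """,
            ("sqlite",),
        ).fetchall()
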
@@ -197,6 +348,7 @@ def validate_keyword(keyword):
         raise ValueError("Keyword contains invalid characters")
     return keyword.strip()
 
+
 def validate_collection_name(name):
     if not isinstance(name, str):
         raise ValueError("Collection name must be a string")
@@ -208,6 +360,7 @@
         raise ValueError("Collection name contains invalid characters")
     return name.strip()
 
+
 # Core functions
 def add_keyword(keyword, conn=None):
     try:
@@ -222,6 +375,7 @@
         logger.error(f"Error adding keyword '{keyword}': {e}")
         raise
 
+
 def create_keyword_collection(name, parent_id=None):
     try:
         validated_name = validate_collection_name(name)
@@ -235,6 +389,7 @@
         logger.error(f"Error creating keyword collection '{name}': {e}")
         raise
 
+
 def add_keyword_to_collection(collection_name, keyword):
     try:
         validated_collection_name = validate_collection_name(collection_name)
@@ -259,6 +414,7 @@
         logger.error(f"Error adding keyword '{keyword}' to collection '{collection_name}': {e}")
         raise
 
+
 def add_keywords_to_conversation(conversation_id, keywords):
     if not isinstance(keywords, (list, tuple)):
         raise ValueError("Keywords must be a list or tuple")
@@ -282,6 +438,23 @@
         logger.error(f"Error adding keywords to conversation '{conversation_id}': {e}")
         raise
 
+
+def view_rag_keywords():
+    try:
+        rag_db_path = get_rag_qa_db_path()
+        with sqlite3.connect(rag_db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT keyword FROM rag_qa_keywords ORDER BY keyword")
+            keywords = cursor.fetchall()
+            if keywords:
+                keyword_list = [k[0] for k in keywords]
+                return "### Current RAG QA Keywords:\n" + "\n".join(
+                    [f"- {k}" for k in keyword_list])
+            return "No keywords found."
+    except Exception as e:
+        return f"Error retrieving keywords: {str(e)}"
+
+
 def get_keywords_for_conversation(conversation_id):
     try:
         query = '''
@@ -298,6 +471,7 @@
         logger.error(f"Error getting keywords for conversation '{conversation_id}': {e}")
         raise
 
+
 def get_keywords_for_collection(collection_name):
     try:
         query = '''
@@ -315,6 +489,116 @@
         logger.error(f"Error getting keywords for collection '{collection_name}': {e}")
         raise
 
+
+def delete_rag_keyword(keyword: str) -> str:
+    """
+    Delete a keyword from the RAG QA database and all its associations.
+
+    Args:
+        keyword (str): The keyword to delete
+
+    Returns:
+        str: Success/failure message
+    """
+    try:
+        # Validate the keyword
+        validated_keyword = validate_keyword(keyword)
+
+        with transaction() as conn:
+            # First, get the keyword ID
+            cursor = conn.cursor()
+            cursor.execute("SELECT id FROM rag_qa_keywords WHERE keyword = ?", (validated_keyword,))
+            result = cursor.fetchone()
+
+            if not result:
+                return f"Keyword '{validated_keyword}' not found."
+
+            keyword_id = result[0]
+
+            # Delete from all associated tables
+            cursor.execute("DELETE FROM rag_qa_conversation_keywords WHERE keyword_id = ?", (keyword_id,))
+            cursor.execute("DELETE FROM rag_qa_collection_keywords WHERE keyword_id = ?", (keyword_id,))
+            cursor.execute("DELETE FROM rag_qa_note_keywords WHERE keyword_id = ?", (keyword_id,))
+
+            # Finally, delete the keyword itself
+            cursor.execute("DELETE FROM rag_qa_keywords WHERE id = ?", (keyword_id,))
+
+            logger.info(f"Keyword '{validated_keyword}' deleted successfully")
+            return f"Successfully deleted keyword '{validated_keyword}' and all its associations."
+
+    except ValueError as e:
+        error_msg = f"Invalid keyword: {str(e)}"
+        logger.error(error_msg)
+        return error_msg
+    except Exception as e:
+        error_msg = f"Error deleting keyword: {str(e)}"
+        logger.error(error_msg)
+        return error_msg
+
+
+def export_rag_keywords_to_csv() -> Tuple[str, str]:
+    """
+    Export all RAG QA keywords to a CSV file.
+
+    Returns:
+        Tuple[str, str]: (status_message, file_path)
+    """
+    import csv
+    from tempfile import NamedTemporaryFile
+    from datetime import datetime
+
+    try:
+        # Create a temporary CSV file
+        temp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', newline='')
+
+        with transaction() as conn:
+            cursor = conn.cursor()
+
+            # Get all keywords and their associations
+            query = """
+                SELECT
+                    k.keyword,
+                    GROUP_CONCAT(DISTINCT c.name) as collections,
+                    COUNT(DISTINCT ck.conversation_id) as num_conversations,
+                    COUNT(DISTINCT nk.note_id) as num_notes
+                FROM rag_qa_keywords k
+                LEFT JOIN rag_qa_collection_keywords col_k ON k.id = col_k.keyword_id
+                LEFT JOIN rag_qa_keyword_collections c ON col_k.collection_id = c.id
+                LEFT JOIN rag_qa_conversation_keywords ck ON k.id = ck.keyword_id
+                LEFT JOIN rag_qa_note_keywords nk ON k.id = nk.keyword_id
+                GROUP BY k.id, k.keyword
+                ORDER BY k.keyword
+            """
+
+            cursor.execute(query)
+            results = cursor.fetchall()
+
+            # Write to CSV
+            writer = csv.writer(temp_file)
+            writer.writerow(['Keyword', 'Collections', 'Number of Conversations', 'Number of Notes'])
+
+            for row in results:
+                writer.writerow([
+                    row[0],  # keyword
+                    row[1] if row[1] else '',  # collections (may be None)
+                    row[2],  # num_conversations
+                    row[3]  # num_notes
+                ])
+
+        temp_file.close()
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        status_msg = f"Successfully exported {len(results)} keywords to CSV."
+        logger.info(status_msg)
+
+        return status_msg, temp_file.name
+
+    except Exception as e:
+        error_msg = f"Error exporting keywords: {str(e)}"
+        logger.error(error_msg)
+        return error_msg, ""
+
+
 #
 # End of Keyword-related functions
 ###################################################
@@ -339,6 +623,7 @@
         logger.error(f"Error saving notes for conversation '{conversation_id}': {e}")
         raise
 
+
 def update_note(note_id, title, content):
     try:
         query = "UPDATE rag_qa_notes SET title = ?, content = ?, timestamp = ? WHERE id = ?"
@@ -349,6 +634,121 @@
         logger.error(f"Error updating note ID '{note_id}': {e}")
         raise
 
+
+def search_notes_titles(search_term: str, page: int = 1, results_per_page: int = 20, connection=None) -> Tuple[
+    List[Tuple], int, int]:
+    """
+    Search note titles using full-text search. Returns all notes if search_term is empty.
+
+    Args:
+        search_term (str): The search term for note titles. If empty, returns all notes.
+        page (int, optional): Page number for pagination. Defaults to 1.
+        results_per_page (int, optional): Number of results per page. Defaults to 20.
+        connection (sqlite3.Connection, optional): Database connection. Uses a new connection if not provided.
+
+    Returns:
+        Tuple[List[Tuple], int, int]: Tuple containing:
+            - List of tuples: (note_id, title, content, timestamp, conversation_id)
+            - Total number of pages
+            - Total count of matching records
+
+    Raises:
+        ValueError: If page number is less than 1
+        sqlite3.Error: If there's a database error
+    """
+    if page < 1:
+        raise ValueError("Page number must be 1 or greater.")
+
+    offset = (page - 1) * results_per_page
+
+    def execute_search(conn):
+        cursor = conn.cursor()
+
+        # Debug: Show table contents
+        cursor.execute("SELECT title FROM rag_qa_notes")
+        main_titles = cursor.fetchall()
+        logger.debug(f"Main table titles: {main_titles}")
+
+        cursor.execute("SELECT title FROM rag_qa_notes_fts")
+        fts_titles = cursor.fetchall()
+        logger.debug(f"FTS table titles: {fts_titles}")
+
+        if not search_term.strip():
+            # Query for all notes
+            cursor.execute(
+                """
+                SELECT COUNT(*)
+                FROM rag_qa_notes
+                """
+            )
+            total_count = cursor.fetchone()[0]
+
+            cursor.execute(
+                """
+                SELECT id, title, content, timestamp, conversation_id
+                FROM rag_qa_notes
+                ORDER BY timestamp DESC
+                LIMIT ? OFFSET ?
+                """,
+                (results_per_page, offset)
+            )
+            results = cursor.fetchall()
+        else:
+            # Search query
+            search_term_clean = search_term.strip().lower()
+
+            # Test direct FTS search
+            cursor.execute(
+                """
+                SELECT COUNT(*)
+                FROM rag_qa_notes n
+                JOIN rag_qa_notes_fts fts ON n.id = fts.rowid
+                WHERE fts.title MATCH ?
+                """,
+                (search_term_clean,)
+            )
+            total_count = cursor.fetchone()[0]
+
+            cursor.execute(
+                """
+                SELECT
+                    n.id,
+                    n.title,
+                    n.content,
+                    n.timestamp,
+                    n.conversation_id
+                FROM rag_qa_notes n
+                JOIN rag_qa_notes_fts fts ON n.id = fts.rowid
+                WHERE fts.title MATCH ?
+                ORDER BY rank
+                LIMIT ? OFFSET ?
+                """,
+                (search_term_clean, results_per_page, offset)
+            )
+            results = cursor.fetchall()
+
+        # Log the raw term, not search_term_clean, which is only bound in the else branch
+        logger.debug(f"Search term: {search_term!r}")
+        logger.debug(f"Results: {results}")
+
+        total_pages = max(1, (total_count + results_per_page - 1) // results_per_page)
+        logger.info(f"Found {total_count} matching notes for search term '{search_term}'")
+
+        return results, total_pages, total_count
+
+    try:
+        if connection:
+            return execute_search(connection)
+        else:
+            with get_db_connection() as conn:
+                return execute_search(conn)
+
+    except sqlite3.Error as e:
+        logger.error(f"Database error in search_notes_titles: {str(e)}")
+        logger.error(f"Search term: {search_term}")
+        raise sqlite3.Error(f"Error searching notes: {str(e)}")
+
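Usage sketch for `search_notes_titles`; an empty search term falls back to a plain paginated listing, newest first:

    results, total_pages, total_count = search_notes_titles("meeting", page=1, results_per_page=20)
    for note_id, title, content, timestamp, conversation_id in results:
        print(note_id, title)

    all_notes, pages, count = search_notes_titles("")  # empty term -> all notes
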
+
+
 def get_notes(conversation_id):
     """Retrieve notes for a given conversation."""
     try:
@@ -361,6 +761,7 @@
         logger.error(f"Error getting notes for conversation '{conversation_id}': {e}")
         raise
 
+
 def get_note_by_id(note_id):
     try:
         query = "SELECT id, title, content FROM rag_qa_notes WHERE id = ?"
@@ -370,9 +771,21 @@
         logger.error(f"Error getting note by ID '{note_id}': {e}")
         raise
 
+
 def get_notes_by_keywords(keywords, page=1, page_size=20):
     try:
-        placeholders = ','.join(['?'] * len(keywords))
+        # Handle empty or invalid keywords
+        if not keywords or not isinstance(keywords, (list, tuple)) or len(keywords) == 0:
+            return [], 0, 0
+
+        # Convert all keywords to strings and strip them
+        clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()]
+
+        # If no valid keywords after cleaning, return empty result
+        if not clean_keywords:
+            return [], 0, 0
+
+        placeholders = ','.join(['?'] * len(clean_keywords))
         query = f'''
             SELECT n.id, n.title, n.content, n.timestamp
             FROM rag_qa_notes n
@@ -381,14 +794,15 @@
             WHERE k.keyword IN ({placeholders})
             ORDER BY n.timestamp DESC
         '''
-        results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size)
-        logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(keywords)} (page {page} of {total_pages})")
+        results, total_pages, total_count = get_paginated_results(query, tuple(clean_keywords), page, page_size)
+        logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(clean_keywords)} (page {page} of {total_pages})")
         notes = [(row[0], row[1], row[2], row[3]) for row in results]
         return notes, total_pages, total_count
     except Exception as e:
         logger.error(f"Error getting notes by keywords: {e}")
         raise
 
+
 def get_notes_by_keyword_collection(collection_name, page=1, page_size=20):
     try:
         query = '''
@@ -501,9 +915,10 @@
 #
 # Chat-related functions
 
-def save_message(conversation_id, role, content):
+def save_message(conversation_id, role, content, timestamp=None):
     try:
-        timestamp = datetime.now().isoformat()
+        if timestamp is None:
+            timestamp = datetime.now().isoformat()
         query = "INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content) VALUES (?, ?, ?, ?)"
         execute_query(query, (conversation_id, timestamp, role, content))
@@ -516,29 +931,103 @@
         logger.error(f"Error saving message for conversation '{conversation_id}': {e}")
         raise
 
+ """ now = datetime.now().isoformat() - execute_query(query, (conversation_id, now, now, title)) - logger.info(f"New conversation '{conversation_id}' started with title '{title}'") + # Set initial rating to NULL + execute_query(query, (conversation_id, now, now, title, media_id, None)) + logger.info(f"New conversation '{conversation_id}' started with title '{title}' and media_id '{media_id}'") return conversation_id except Exception as e: logger.error(f"Error starting new conversation: {e}") raise + def get_all_conversations(page=1, page_size=20): try: - query = "SELECT conversation_id, title FROM conversation_metadata ORDER BY last_updated DESC" - results, total_pages, total_count = get_paginated_results(query, page=page, page_size=page_size) - conversations = [(row[0], row[1]) for row in results] - logger.info(f"Retrieved {len(conversations)} conversations (page {page} of {total_pages})") - return conversations, total_pages, total_count + query = """ + SELECT conversation_id, title, media_id, rating + FROM conversation_metadata + ORDER BY last_updated DESC + LIMIT ? OFFSET ? + """ + + count_query = "SELECT COUNT(*) FROM conversation_metadata" + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + + # Get total count + cursor.execute(count_query) + total_count = cursor.fetchone()[0] + total_pages = (total_count + page_size - 1) // page_size + + # Get page of results + offset = (page - 1) * page_size + cursor.execute(query, (page_size, offset)) + results = cursor.fetchall() + + conversations = [{ + 'conversation_id': row[0], + 'title': row[1], + 'media_id': row[2], + 'rating': row[3] # Include rating + } for row in results] + return conversations, total_pages, total_count except Exception as e: - logger.error(f"Error getting conversations: {e}") + logging.error(f"Error getting conversations: {e}") raise + +def get_all_notes(page=1, page_size=20): + try: + query = """ + SELECT n.id, n.conversation_id, n.title, n.content, n.timestamp, + cm.title as conversation_title, cm.media_id + FROM rag_qa_notes n + LEFT JOIN conversation_metadata cm ON n.conversation_id = cm.conversation_id + ORDER BY n.timestamp DESC + LIMIT ? OFFSET ? 
+ """ + + count_query = "SELECT COUNT(*) FROM rag_qa_notes" + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + + # Get total count + cursor.execute(count_query) + total_count = cursor.fetchone()[0] + total_pages = (total_count + page_size - 1) // page_size + + # Get page of results + offset = (page - 1) * page_size + cursor.execute(query, (page_size, offset)) + results = cursor.fetchall() + + notes = [{ + 'id': row[0], + 'conversation_id': row[1], + 'title': row[2], + 'content': row[3], + 'timestamp': row[4], + 'conversation_title': row[5], + 'media_id': row[6] + } for row in results] + + return notes, total_pages, total_count + except Exception as e: + logging.error(f"Error getting notes: {e}") + raise + + # Pagination helper function def get_paginated_results(query, params=None, page=1, page_size=20): try: @@ -564,6 +1053,7 @@ def get_paginated_results(query, params=None, page=1, page_size=20): logger.error(f"Error retrieving paginated results: {e}") raise + def get_all_collections(page=1, page_size=20): try: query = "SELECT name FROM rag_qa_keyword_collections" @@ -575,24 +1065,79 @@ def get_all_collections(page=1, page_size=20): logger.error(f"Error getting collections: {e}") raise -def search_conversations_by_keywords(keywords, page=1, page_size=20): + +def search_conversations_by_keywords(keywords=None, title_query=None, content_query=None, page=1, page_size=20): try: - placeholders = ','.join(['?' for _ in keywords]) - query = f''' - SELECT DISTINCT cm.conversation_id, cm.title + # Base query starts with conversation metadata + query = """ + SELECT DISTINCT cm.conversation_id, cm.title, cm.last_updated FROM conversation_metadata cm - JOIN rag_qa_conversation_keywords ck ON cm.conversation_id = ck.conversation_id - JOIN rag_qa_keywords k ON ck.keyword_id = k.id - WHERE k.keyword IN ({placeholders}) - ''' - results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size) - logger.info( - f"Found {total_count} conversations matching keywords: {', '.join(keywords)} (page {page} of {total_pages})") - return results, total_pages, total_count + WHERE 1=1 + """ + params = [] + + # Add content search if provided + if content_query and isinstance(content_query, str) and content_query.strip(): + query += """ + AND EXISTS ( + SELECT 1 FROM rag_qa_chats_fts + WHERE rag_qa_chats_fts.content MATCH ? + AND rag_qa_chats_fts.rowid IN ( + SELECT id FROM rag_qa_chats + WHERE conversation_id = cm.conversation_id + ) + ) + """ + params.append(content_query.strip()) + + # Add title search if provided + if title_query and isinstance(title_query, str) and title_query.strip(): + query += """ + AND EXISTS ( + SELECT 1 FROM conversation_metadata_fts + WHERE conversation_metadata_fts.title MATCH ? + AND conversation_metadata_fts.rowid = cm.rowid + ) + """ + params.append(title_query.strip()) + + # Add keyword search if provided + if keywords and isinstance(keywords, (list, tuple)) and len(keywords) > 0: + # Convert all keywords to strings and strip them + clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()] + if clean_keywords: # Only add to query if we have valid keywords + placeholders = ','.join(['?' 
-def search_conversations_by_keywords(keywords, page=1, page_size=20):
+
+def search_conversations_by_keywords(keywords=None, title_query=None, content_query=None, page=1, page_size=20):
     try:
-        placeholders = ','.join(['?' for _ in keywords])
-        query = f'''
-            SELECT DISTINCT cm.conversation_id, cm.title
+        # Base query starts with conversation metadata
+        query = """
+            SELECT DISTINCT cm.conversation_id, cm.title, cm.last_updated
             FROM conversation_metadata cm
-            JOIN rag_qa_conversation_keywords ck ON cm.conversation_id = ck.conversation_id
-            JOIN rag_qa_keywords k ON ck.keyword_id = k.id
-            WHERE k.keyword IN ({placeholders})
-        '''
-        results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size)
-        logger.info(
-            f"Found {total_count} conversations matching keywords: {', '.join(keywords)} (page {page} of {total_pages})")
-        return results, total_pages, total_count
+            WHERE 1=1
+        """
+        params = []
+
+        # Add content search if provided
+        if content_query and isinstance(content_query, str) and content_query.strip():
+            query += """
+                AND EXISTS (
+                    SELECT 1 FROM rag_qa_chats_fts
+                    WHERE rag_qa_chats_fts.content MATCH ?
+                    AND rag_qa_chats_fts.rowid IN (
+                        SELECT id FROM rag_qa_chats
+                        WHERE conversation_id = cm.conversation_id
+                    )
+                )
+            """
+            params.append(content_query.strip())
+
+        # Add title search if provided
+        if title_query and isinstance(title_query, str) and title_query.strip():
+            query += """
+                AND EXISTS (
+                    SELECT 1 FROM conversation_metadata_fts
+                    WHERE conversation_metadata_fts.title MATCH ?
+                    AND conversation_metadata_fts.rowid = cm.rowid
+                )
+            """
+            params.append(title_query.strip())
+
+        # Add keyword search if provided
+        if keywords and isinstance(keywords, (list, tuple)) and len(keywords) > 0:
+            # Convert all keywords to strings and strip them
+            clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()]
+            if clean_keywords:  # Only add to query if we have valid keywords
+                placeholders = ','.join(['?' for _ in clean_keywords])
+                query += f"""
+                    AND EXISTS (
+                        SELECT 1 FROM rag_qa_conversation_keywords ck
+                        JOIN rag_qa_keywords k ON ck.keyword_id = k.id
+                        WHERE ck.conversation_id = cm.conversation_id
+                        AND k.keyword IN ({placeholders})
+                    )
+                """
+                params.extend(clean_keywords)
+
+        # Add ordering
+        query += " ORDER BY cm.last_updated DESC"
+
+        results, total_pages, total_count = get_paginated_results(query, tuple(params), page, page_size)
+
+        conversations = [
+            {
+                'conversation_id': row[0],
+                'title': row[1],
+                'last_updated': row[2]
+            }
+            for row in results
+        ]
+
+        return conversations, total_pages, total_count
+
     except Exception as e:
-        logger.error(f"Error searching conversations by keywords {keywords}: {e}")
+        logger.error(f"Error searching conversations: {e}")
         raise
 
+
 def load_chat_history(conversation_id, page=1, page_size=50):
     try:
         query = "SELECT role, content FROM rag_qa_chats WHERE conversation_id = ? ORDER BY timestamp"
@@ -604,6 +1149,7 @@
         logger.error(f"Error loading chat history for conversation '{conversation_id}': {e}")
         raise
 
+
 def update_conversation_title(conversation_id, new_title):
     """Update the title of a conversation."""
     try:
@@ -614,6 +1160,7 @@
         logger.error(f"Error updating conversation title: {e}")
         raise
 
+
 def delete_messages_in_conversation(conversation_id):
     """Helper function to delete all messages in a conversation."""
     try:
@@ -623,6 +1170,7 @@
         logging.error(f"Error deleting messages in conversation '{conversation_id}': {e}")
         raise
 
+
 def get_conversation_title(conversation_id):
     """Helper function to get the conversation title."""
     query = "SELECT title FROM conversation_metadata WHERE conversation_id = ?"
@@ -632,6 +1180,39 @@
     else:
         return "Untitled Conversation"
 
+
+def get_conversation_text(conversation_id):
+    try:
+        query = """
+            SELECT role, content
+            FROM rag_qa_chats
+            WHERE conversation_id = ?
+            ORDER BY timestamp ASC
+        """
+
+        messages = []
+        # Use the connection as a context manager
+        db_path = get_rag_qa_db_path()
+        with sqlite3.connect(db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute(query, (conversation_id,))
+            messages = cursor.fetchall()
+
+        return "\n\n".join([f"{msg[0]}: {msg[1]}" for msg in messages])
+    except Exception as e:
+        logger.error(f"Error getting conversation text: {e}")
+        raise
+
+
+def get_conversation_details(conversation_id):
+    query = "SELECT title, media_id, rating FROM conversation_metadata WHERE conversation_id = ?"
+    result = execute_query(query, (conversation_id,))
+    if result:
+        return {'title': result[0][0], 'media_id': result[0][1], 'rating': result[0][2]}
+    else:
+        return {'title': "Untitled Conversation", 'media_id': None, 'rating': None}
+
+
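A small sketch combining the helpers above to create, populate, and dump a conversation (all names are defined in this file):

    conv_id = start_new_conversation(title="Demo", media_id=None)
    save_message(conv_id, "user", "Hello")
    save_message(conv_id, "assistant", "Hi there")

    details = get_conversation_details(conv_id)
    print(details['title'], details['rating'])  # "Demo", None (unrated)
    print(get_conversation_text(conv_id))       # "user: Hello\n\nassistant: Hi there"
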
+ execute_query(query, (rating, conversation_id)) + logger.info(f"Rating for conversation '{conversation_id}' set to {rating}") + except Exception as e: + logger.error(f"Error setting rating for conversation '{conversation_id}': {e}") + raise + +def get_conversation_rating(conversation_id): + """Get the rating of a conversation.""" + try: + query = "SELECT rating FROM conversation_metadata WHERE conversation_id = ?" + result = execute_query(query, (conversation_id,)) + if result: + rating = result[0][0] + logger.info(f"Rating for conversation '{conversation_id}' is {rating}") + return rating + else: + logger.warning(f"Conversation '{conversation_id}' not found.") + return None + except Exception as e: + logger.error(f"Error getting rating for conversation '{conversation_id}': {e}") + raise + + +def get_conversation_name(conversation_id: str) -> str: + """ + Retrieves the title/name of a conversation from the conversation_metadata table. + + Args: + conversation_id (str): The unique identifier of the conversation + + Returns: + str: The title of the conversation if found, "Untitled Conversation" if not found + + Raises: + sqlite3.Error: If there's a database error + """ + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT title FROM conversation_metadata WHERE conversation_id = ?", + (conversation_id,) + ) + result = cursor.fetchone() + + if result: + return result[0] + else: + logging.warning(f"No conversation found with ID: {conversation_id}") + return "Untitled Conversation" + + except sqlite3.Error as e: + logging.error(f"Database error retrieving conversation name for ID {conversation_id}: {e}") + raise + except Exception as e: + logging.error(f"Unexpected error retrieving conversation name for ID {conversation_id}: {e}") + raise + + +def search_rag_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the RAG Chat database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of media IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + # Perform the full-text search using the FTS virtual table + cursor.execute(""" + SELECT rag_qa_chats.id, rag_qa_chats.conversation_id, rag_qa_chats.role, rag_qa_chats.content + FROM rag_qa_chats_fts + JOIN rag_qa_chats ON rag_qa_chats_fts.rowid = rag_qa_chats.id + WHERE rag_qa_chats_fts MATCH ? + LIMIT ? 
+ """, (query, fts_top_k)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + # Filter by relevant_media_ids if provided + if relevant_media_ids is not None: + results = [ + r for r in results + if get_conversation_details(r['conversation_id']).get('media_id') in relevant_media_ids + ] + + # Format results + formatted_results = [ + { + "content": r['content'], + "metadata": { + "conversation_id": r['conversation_id'], + "role": r['role'], + "media_id": get_conversation_details(r['conversation_id']).get('media_id') + } + } + for r in results + ] + return formatted_results + + except Exception as e: + logging.error(f"Error in search_rag_chat: {e}") + return [] + + +def search_rag_notes(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the RAG Notes database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of media IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + # Perform the full-text search using the FTS virtual table + cursor.execute(""" + SELECT rag_qa_notes.id, rag_qa_notes.title, rag_qa_notes.content, rag_qa_notes.conversation_id + FROM rag_qa_notes_fts + JOIN rag_qa_notes ON rag_qa_notes_fts.rowid = rag_qa_notes.id + WHERE rag_qa_notes_fts MATCH ? + LIMIT ? + """, (query, fts_top_k)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + # Filter by relevant_media_ids if provided + if relevant_media_ids is not None: + results = [ + r for r in results + if get_conversation_details(r['conversation_id']).get('media_id') in relevant_media_ids + ] + + # Format results + formatted_results = [ + { + "content": r['content'], + "metadata": { + "note_id": r['id'], + "title": r['title'], + "conversation_id": r['conversation_id'], + "media_id": get_conversation_details(r['conversation_id']).get('media_id') + } + } + for r in results + ] + return formatted_results + + except Exception as e: + logging.error(f"Error in search_rag_notes: {e}") + return [] + # # End of Chat-related functions ################################################### +################################################### +# +# Import functions + + +# +# End of Import functions +################################################### + + ################################################### # # Functions to export DB data diff --git a/App_Function_Libraries/DB/SQLite_DB.py b/App_Function_Libraries/DB/SQLite_DB.py index 7ec3bf2e5..6c6c44024 100644 --- a/App_Function_Libraries/DB/SQLite_DB.py +++ b/App_Function_Libraries/DB/SQLite_DB.py @@ -21,7 +21,7 @@ # 11. browse_items(search_query, search_type) # 12. fetch_item_details(media_id: int) # 13. add_media_version(media_id: int, prompt: str, summary: str) -# 14. search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10) +# 14. search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10) # 15. search_and_display(search_query, search_fields, keywords, page) # 16. display_details(index, results) # 17. 
get_details(index, dataframe) @@ -55,12 +55,14 @@ import shutil import sqlite3 import threading +import time import traceback from contextlib import contextmanager from datetime import datetime, timedelta from typing import List, Tuple, Dict, Any, Optional from urllib.parse import quote +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # Local Libraries from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path, \ get_database_dir @@ -342,27 +344,6 @@ def create_tables(db) -> None: ) ''', ''' - CREATE TABLE IF NOT EXISTS ChatConversations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - media_id INTEGER, - media_name TEXT, - conversation_name TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (media_id) REFERENCES Media(id) - ) - ''', - ''' - CREATE TABLE IF NOT EXISTS ChatMessages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - conversation_id INTEGER, - sender TEXT, - message TEXT, - timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES ChatConversations(id) - ) - ''', - ''' CREATE TABLE IF NOT EXISTS Transcripts ( id INTEGER PRIMARY KEY AUTOINCREMENT, media_id INTEGER, @@ -421,8 +402,6 @@ def create_tables(db) -> None: 'CREATE INDEX IF NOT EXISTS idx_mediakeywords_keyword_id ON MediaKeywords(keyword_id)', 'CREATE INDEX IF NOT EXISTS idx_media_version_media_id ON MediaVersion(media_id)', 'CREATE INDEX IF NOT EXISTS idx_mediamodifications_media_id ON MediaModifications(media_id)', - 'CREATE INDEX IF NOT EXISTS idx_chatconversations_media_id ON ChatConversations(media_id)', - 'CREATE INDEX IF NOT EXISTS idx_chatmessages_conversation_id ON ChatMessages(conversation_id)', 'CREATE INDEX IF NOT EXISTS idx_media_is_trash ON Media(is_trash)', 'CREATE INDEX IF NOT EXISTS idx_mediachunks_media_id ON MediaChunks(media_id)', 'CREATE INDEX IF NOT EXISTS idx_unvectorized_media_chunks_media_id ON UnvectorizedMediaChunks(media_id)', @@ -606,7 +585,10 @@ def mark_media_as_processed(database, media_id): # Function to add media with keywords def add_media_with_keywords(url, title, media_type, content, keywords, prompt, summary, transcription_model, author, ingestion_date): + log_counter("add_media_with_keywords_attempt") + start_time = time.time() logging.debug(f"Entering add_media_with_keywords: URL={url}, Title={title}") + # Set default values for missing fields if url is None: url = 'localhost' @@ -622,10 +604,17 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s author = author or 'Unknown' ingestion_date = ingestion_date or datetime.now().strftime('%Y-%m-%d') - if media_type not in ['article', 'audio', 'document', 'mediawiki_article', 'mediawiki_dump', 'obsidian_note', 'podcast', 'text', 'video', 'unknown']: - raise InputError("Invalid media type. Allowed types: article, audio file, document, obsidian_note podcast, text, video, unknown.") + if media_type not in ['article', 'audio', 'book', 'document', 'mediawiki_article', 'mediawiki_dump', + 'obsidian_note', 'podcast', 'text', 'video', 'unknown']: + log_counter("add_media_with_keywords_error", labels={"error_type": "InvalidMediaType"}) + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + raise InputError("Invalid media type. 
Allowed types: article, audio, book, document, mediawiki_article, mediawiki_dump, obsidian_note, podcast, text, video, unknown.")
 
     if ingestion_date and not is_valid_date(ingestion_date):
+        log_counter("add_media_with_keywords_error", labels={"error_type": "InvalidDateFormat"})
+        duration = time.time() - start_time
+        log_histogram("add_media_with_keywords_duration", duration)
         raise InputError("Invalid ingestion date format. Use YYYY-MM-DD.")
 
     # Handle keywords as either string or list
@@ -654,6 +643,7 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s
             logging.debug(f"Existing media ID for {url}: {existing_media_id}")
 
             if existing_media_id:
+                # Update existing media
                 media_id = existing_media_id
                 logging.debug(f"Updating existing media with ID: {media_id}")
                 cursor.execute('''
@@ -661,7 +651,9 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s
                     SET content = ?, transcription_model = ?, type = ?, author = ?, ingestion_date = ?
                     WHERE id = ?
                 ''', (content, transcription_model, media_type, author, ingestion_date, media_id))
+                log_counter("add_media_with_keywords_update")
             else:
+                # Insert new media
                 logging.debug("Inserting new media")
                 cursor.execute('''
                     INSERT INTO Media (url, title, type, content, author, ingestion_date, transcription_model)
@@ -669,6 +661,7 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s
                 ''', (url, title, media_type, content, author, ingestion_date, transcription_model))
                 media_id = cursor.lastrowid
                 logging.debug(f"New media inserted with ID: {media_id}")
+                log_counter("add_media_with_keywords_insert")
 
             cursor.execute('''
                 INSERT INTO MediaModifications (media_id, prompt, summary, modification_date)
@@ -698,13 +691,23 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s
             conn.commit()
             logging.info(f"Media '{title}' successfully added/updated with ID: {media_id}")
 
-        return media_id, f"Media '{title}' added/updated successfully with keywords: {', '.join(keyword_list)}"
+        duration = time.time() - start_time
+        log_histogram("add_media_with_keywords_duration", duration)
+        log_counter("add_media_with_keywords_success")
+
+        return media_id, f"Media '{title}' added/updated successfully with keywords: {', '.join(keyword_list)}"
 
     except sqlite3.Error as e:
         logging.error(f"SQL Error in add_media_with_keywords: {e}")
+        duration = time.time() - start_time
+        log_histogram("add_media_with_keywords_duration", duration)
+        log_counter("add_media_with_keywords_error", labels={"error_type": "SQLiteError"})
         raise DatabaseError(f"Error adding media with keywords: {e}")
     except Exception as e:
         logging.error(f"Unexpected Error in add_media_with_keywords: {e}")
+        duration = time.time() - start_time
+        log_histogram("add_media_with_keywords_duration", duration)
+        log_counter("add_media_with_keywords_error", labels={"error_type": type(e).__name__})
         raise DatabaseError(f"Unexpected error: {e}")
 
@@ -779,7 +782,13 @@ def ingest_article_to_db(url, title, author, content, keywords, summary, ingesti
 
 # Function to add a keyword
 def add_keyword(keyword: str) -> int:
+    log_counter("add_keyword_attempt")
+    start_time = time.time()
+
     if not keyword.strip():
+        log_counter("add_keyword_error", labels={"error_type": "EmptyKeyword"})
+        duration = time.time() - start_time
+        log_histogram("add_keyword_duration", duration)
         raise DatabaseError("Keyword cannot be empty")
 
     keyword = keyword.strip().lower()
@@ -801,18 +810,32 @@ def add_keyword(keyword: str) -> int:
             logging.info(f"Keyword '{keyword}' added or updated with ID: {keyword_id}")
conn.commit() + + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_success") + return keyword_id except sqlite3.IntegrityError as e: logging.error(f"Integrity error adding keyword: {e}") + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_error", labels={"error_type": "IntegrityError"}) raise DatabaseError(f"Integrity error adding keyword: {e}") except sqlite3.Error as e: logging.error(f"Error adding keyword: {e}") + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_error", labels={"error_type": "SQLiteError"}) raise DatabaseError(f"Error adding keyword: {e}") # Function to delete a keyword def delete_keyword(keyword: str) -> str: + log_counter("delete_keyword_attempt") + start_time = time.time() + keyword = keyword.strip().lower() with db.get_connection() as conn: cursor = conn.cursor() @@ -823,10 +846,23 @@ def delete_keyword(keyword: str) -> str: cursor.execute('DELETE FROM Keywords WHERE keyword = ?', (keyword,)) cursor.execute('DELETE FROM keyword_fts WHERE rowid = ?', (keyword_id[0],)) conn.commit() + + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_success") + return f"Keyword '{keyword}' deleted successfully." else: + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_not_found") + return f"Keyword '{keyword}' not found." except sqlite3.Error as e: + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_error", labels={"error_type": type(e).__name__}) + logging.error(f"Error deleting keyword: {e}") raise DatabaseError(f"Error deleting keyword: {e}") @@ -1000,7 +1036,7 @@ def add_media_version(conn, media_id: int, prompt: str, summary: str) -> None: # Function to search the database with advanced options, including keyword search and full-text search -def sqlite_search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10, connection=None): +def search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 20, connection=None): if page < 1: raise ValueError("Page number must be 1 or greater.") @@ -1055,7 +1091,7 @@ def execute_query(conn): # Gradio function to handle user input and display results with pagination, with better feedback def search_and_display(search_query, search_fields, keywords, page): - results = sqlite_search_db(search_query, search_fields, keywords, page) + results = search_media_db(search_query, search_fields, keywords, page) if isinstance(results, pd.DataFrame): # Convert DataFrame to a list of tuples or lists @@ -1133,7 +1169,7 @@ def format_results(results): # Function to export search results to CSV or markdown with pagination def export_to_file(search_query: str, search_fields: List[str], keyword: str, page: int = 1, results_per_file: int = 1000, export_format: str = 'csv'): try: - results = sqlite_search_db(search_query, search_fields, keyword, page, results_per_file) + results = search_media_db(search_query, search_fields, keyword, page, results_per_file) if not results: return "No results found to export." 
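The attempt/duration/success/error bookkeeping threaded through add_media_with_keywords, add_keyword, and delete_keyword above repeats the same four-step pattern inline. A decorator could factor it out; a minimal sketch, assuming only the log_counter/log_histogram signatures already used in this diff (the decorator itself is not part of the PR):

import time
from functools import wraps

from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram


def instrumented(metric_name):
    # Wrap a DB helper with attempt/success/error counters and a
    # duration histogram, mirroring the inline pattern in this diff.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            log_counter(f"{metric_name}_attempt")
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
                log_counter(f"{metric_name}_success")
                return result
            except Exception as e:
                log_counter(f"{metric_name}_error", labels={"error_type": type(e).__name__})
                raise
            finally:
                # Runs on both the success and error paths.
                log_histogram(f"{metric_name}_duration", time.time() - start_time)
        return wrapper
    return decorator

# Hypothetical usage: @instrumented("add_keyword") above the add_keyword definition.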
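For callers of the reworked search_conversations_by_keywords earlier in this diff, all three filters are now optional and AND-ed together; a hedged usage sketch (the filter values are invented):

from App_Function_Libraries.DB.RAG_QA_Chat_DB import search_conversations_by_keywords

# Any subset of keywords / title_query / content_query may be passed.
conversations, total_pages, total_count = search_conversations_by_keywords(
    keywords=["sqlite", "fts5"],   # matched via rag_qa_conversation_keywords
    title_query="migration",       # matched via conversation_metadata_fts
    content_query="chunking",      # matched via rag_qa_chats_fts
    page=1,
    page_size=20,
)
for conv in conversations:
    print(conv["conversation_id"], conv["title"], conv["last_updated"])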
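Likewise for the new search_rag_chat / search_rag_notes helpers added earlier in this diff. Note that relevant_media_ids is compared against the media_id stored in conversation_metadata, so the IDs below are hypothetical and must match that column's type:

from App_Function_Libraries.DB.RAG_QA_Chat_DB import search_rag_chat, search_rag_notes

chat_hits = search_rag_chat("vector embeddings", fts_top_k=5)
note_hits = search_rag_notes(
    "vector embeddings",
    fts_top_k=5,
    relevant_media_ids=["42", "43"],  # hypothetical media IDs
)
for hit in chat_hits + note_hits:
    print(hit["metadata"], hit["content"][:80])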
@@ -1381,303 +1417,6 @@ def schedule_chunking(media_id: int, content: str, media_name: str): ####################################################################################################################### -####################################################################################################################### -# -# Functions to manage prompts DB - -def create_prompts_db(): - logging.debug("create_prompts_db: Creating prompts database.") - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.executescript(''' - CREATE TABLE IF NOT EXISTS Prompts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - author TEXT, - details TEXT, - system TEXT, - user TEXT - ); - CREATE TABLE IF NOT EXISTS Keywords ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - keyword TEXT NOT NULL UNIQUE COLLATE NOCASE - ); - CREATE TABLE IF NOT EXISTS PromptKeywords ( - prompt_id INTEGER, - keyword_id INTEGER, - FOREIGN KEY (prompt_id) REFERENCES Prompts (id), - FOREIGN KEY (keyword_id) REFERENCES Keywords (id), - PRIMARY KEY (prompt_id, keyword_id) - ); - CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON Keywords(keyword); - CREATE INDEX IF NOT EXISTS idx_promptkeywords_prompt_id ON PromptKeywords(prompt_id); - CREATE INDEX IF NOT EXISTS idx_promptkeywords_keyword_id ON PromptKeywords(keyword_id); - ''') - -# FIXME - dirty hack that should be removed later... -# Migration function to add the 'author' column to the Prompts table -def add_author_column_to_prompts(): - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - # Check if 'author' column already exists - cursor.execute("PRAGMA table_info(Prompts)") - columns = [col[1] for col in cursor.fetchall()] - - if 'author' not in columns: - # Add the 'author' column - cursor.execute('ALTER TABLE Prompts ADD COLUMN author TEXT') - print("Author column added to Prompts table.") - else: - print("Author column already exists in Prompts table.") - -add_author_column_to_prompts() - -def normalize_keyword(keyword): - return re.sub(r'\s+', ' ', keyword.strip().lower()) - - -# FIXME - update calls to this function to use the new args -def add_prompt(name, author, details, system=None, user=None, keywords=None): - logging.debug(f"add_prompt: Adding prompt with name: {name}, author: {author}, system: {system}, user: {user}, keywords: {keywords}") - if not name: - logging.error("add_prompt: A name is required.") - return "A name is required." - - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO Prompts (name, author, details, system, user) - VALUES (?, ?, ?, ?, ?) - ''', (name, author, details, system, user)) - prompt_id = cursor.lastrowid - - if keywords: - normalized_keywords = [normalize_keyword(k) for k in keywords if k.strip()] - for keyword in set(normalized_keywords): # Use set to remove duplicates - cursor.execute(''' - INSERT OR IGNORE INTO Keywords (keyword) VALUES (?) - ''', (keyword,)) - cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) - keyword_id = cursor.fetchone()[0] - cursor.execute(''' - INSERT OR IGNORE INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?) - ''', (prompt_id, keyword_id)) - return "Prompt added successfully." - except sqlite3.IntegrityError: - return "Prompt with this name already exists." 
- except sqlite3.Error as e: - return f"Database error: {e}" - - -def fetch_prompt_details(name): - logging.debug(f"fetch_prompt_details: Fetching details for prompt: {name}") - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT p.name, p.author, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords - FROM Prompts p - LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id - LEFT JOIN Keywords k ON pk.keyword_id = k.id - WHERE p.name = ? - GROUP BY p.id - ''', (name,)) - return cursor.fetchone() - - -def list_prompts(page=1, per_page=10): - logging.debug(f"list_prompts: Listing prompts for page {page} with {per_page} prompts per page.") - offset = (page - 1) * per_page - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute('SELECT name FROM Prompts LIMIT ? OFFSET ?', (per_page, offset)) - prompts = [row[0] for row in cursor.fetchall()] - - # Get total count of prompts - cursor.execute('SELECT COUNT(*) FROM Prompts') - total_count = cursor.fetchone()[0] - - total_pages = (total_count + per_page - 1) // per_page - return prompts, total_pages, page - -# This will not scale. For a large number of prompts, use a more efficient method. -# FIXME - see above statement. -def load_preset_prompts(): - logging.debug("load_preset_prompts: Loading preset prompts.") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute('SELECT name FROM Prompts ORDER BY name ASC') - prompts = [row[0] for row in cursor.fetchall()] - return prompts - except sqlite3.Error as e: - print(f"Database error: {e}") - return [] - - -def insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords=None): - return add_prompt(title, author, description, system_prompt, user_prompt, keywords) - - -def get_prompt_db_connection(): - prompt_db_path = get_database_path('prompts.db') - return sqlite3.connect(prompt_db_path) - - -def search_prompts(query): - logging.debug(f"search_prompts: Searching prompts with query: {query}") - try: - with get_prompt_db_connection() as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT p.name, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords - FROM Prompts p - LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id - LEFT JOIN Keywords k ON pk.keyword_id = k.id - WHERE p.name LIKE ? OR p.details LIKE ? OR p.system LIKE ? OR p.user LIKE ? OR k.keyword LIKE ? - GROUP BY p.id - ORDER BY p.name - """, (f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%')) - return cursor.fetchall() - except sqlite3.Error as e: - logging.error(f"Error searching prompts: {e}") - return [] - - -def search_prompts_by_keyword(keyword, page=1, per_page=10): - logging.debug(f"search_prompts_by_keyword: Searching prompts by keyword: {keyword}") - normalized_keyword = normalize_keyword(keyword) - offset = (page - 1) * per_page - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT DISTINCT p.name - FROM Prompts p - JOIN PromptKeywords pk ON p.id = pk.prompt_id - JOIN Keywords k ON pk.keyword_id = k.id - WHERE k.keyword LIKE ? - LIMIT ? OFFSET ? 
- ''', ('%' + normalized_keyword + '%', per_page, offset)) - prompts = [row[0] for row in cursor.fetchall()] - - # Get total count of matching prompts - cursor.execute(''' - SELECT COUNT(DISTINCT p.id) - FROM Prompts p - JOIN PromptKeywords pk ON p.id = pk.prompt_id - JOIN Keywords k ON pk.keyword_id = k.id - WHERE k.keyword LIKE ? - ''', ('%' + normalized_keyword + '%',)) - total_count = cursor.fetchone()[0] - - total_pages = (total_count + per_page - 1) // per_page - return prompts, total_pages, page - - -def update_prompt_keywords(prompt_name, new_keywords): - logging.debug(f"update_prompt_keywords: Updating keywords for prompt: {prompt_name}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - - cursor.execute('SELECT id FROM Prompts WHERE name = ?', (prompt_name,)) - prompt_id = cursor.fetchone() - if not prompt_id: - return "Prompt not found." - prompt_id = prompt_id[0] - - cursor.execute('DELETE FROM PromptKeywords WHERE prompt_id = ?', (prompt_id,)) - - normalized_keywords = [normalize_keyword(k) for k in new_keywords if k.strip()] - for keyword in set(normalized_keywords): # Use set to remove duplicates - cursor.execute('INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)', (keyword,)) - cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) - keyword_id = cursor.fetchone()[0] - cursor.execute('INSERT INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)', - (prompt_id, keyword_id)) - - # Remove unused keywords - cursor.execute(''' - DELETE FROM Keywords - WHERE id NOT IN (SELECT DISTINCT keyword_id FROM PromptKeywords) - ''') - return "Keywords updated successfully." - except sqlite3.Error as e: - return f"Database error: {e}" - - -def add_or_update_prompt(title, author, description, system_prompt, user_prompt, keywords=None): - logging.debug(f"add_or_update_prompt: Adding or updating prompt: {title}") - if not title: - return "Error: Title is required." - - existing_prompt = fetch_prompt_details(title) - if existing_prompt: - # Update existing prompt - result = update_prompt_in_db(title, author, description, system_prompt, user_prompt) - if "successfully" in result: - # Update keywords if the prompt update was successful - keyword_result = update_prompt_keywords(title, keywords or []) - result += f" {keyword_result}" - else: - # Insert new prompt - result = insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords) - - return result - - -def load_prompt_details(selected_prompt): - logging.debug(f"load_prompt_details: Loading prompt details for {selected_prompt}") - if selected_prompt: - details = fetch_prompt_details(selected_prompt) - if details: - return details[0], details[1], details[2], details[3], details[4], details[5] - return "", "", "", "", "", "" - - -def update_prompt_in_db(title, author, description, system_prompt, user_prompt): - logging.debug(f"update_prompt_in_db: Updating prompt: {title}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute( - "UPDATE Prompts SET author = ?, details = ?, system = ?, user = ? WHERE name = ?", - (author, description, system_prompt, user_prompt, title) - ) - if cursor.rowcount == 0: - return "No prompt found with the given title." - return "Prompt updated successfully!" 
- except sqlite3.Error as e: - return f"Error updating prompt: {e}" - - -create_prompts_db() - -def delete_prompt(prompt_id): - logging.debug(f"delete_prompt: Deleting prompt with ID: {prompt_id}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - - # Delete associated keywords - cursor.execute("DELETE FROM PromptKeywords WHERE prompt_id = ?", (prompt_id,)) - - # Delete the prompt - cursor.execute("DELETE FROM Prompts WHERE id = ?", (prompt_id,)) - - if cursor.rowcount == 0: - return f"No prompt found with ID {prompt_id}" - else: - conn.commit() - return f"Prompt with ID {prompt_id} has been successfully deleted" - except sqlite3.Error as e: - return f"An error occurred: {e}" - -# -# -####################################################################################################################### - - ####################################################################################################################### # # Function to fetch/update media content @@ -2020,204 +1759,6 @@ def import_obsidian_note_to_db(note_data): ####################################################################################################################### -####################################################################################################################### -# -# Chat-related Functions - - - -def create_chat_conversation(media_id, conversation_name): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO ChatConversations (media_id, conversation_name, created_at, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (media_id, conversation_name)) - conn.commit() - return cursor.lastrowid - except sqlite3.Error as e: - logging.error(f"Error creating chat conversation: {e}") - raise DatabaseError(f"Error creating chat conversation: {e}") - - -def add_chat_message(conversation_id: int, sender: str, message: str) -> int: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message) - VALUES (?, ?, ?) - ''', (conversation_id, sender, message)) - conn.commit() - return cursor.lastrowid - except sqlite3.Error as e: - logging.error(f"Error adding chat message: {e}") - raise DatabaseError(f"Error adding chat message: {e}") - - -def get_chat_messages(conversation_id: int) -> List[Dict[str, Any]]: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT id, sender, message, timestamp - FROM ChatMessages - WHERE conversation_id = ? - ORDER BY timestamp ASC - ''', (conversation_id,)) - messages = cursor.fetchall() - return [ - { - 'id': msg[0], - 'sender': msg[1], - 'message': msg[2], - 'timestamp': msg[3] - } - for msg in messages - ] - except sqlite3.Error as e: - logging.error(f"Error retrieving chat messages: {e}") - raise DatabaseError(f"Error retrieving chat messages: {e}") - - -def search_chat_conversations(search_query: str) -> List[Dict[str, Any]]: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT cc.id, cc.media_id, cc.conversation_name, cc.created_at, m.title as media_title - FROM ChatConversations cc - LEFT JOIN Media m ON cc.media_id = m.id - WHERE cc.conversation_name LIKE ? OR m.title LIKE ? 
- ORDER BY cc.updated_at DESC - ''', (f'%{search_query}%', f'%{search_query}%')) - conversations = cursor.fetchall() - return [ - { - 'id': conv[0], - 'media_id': conv[1], - 'conversation_name': conv[2], - 'created_at': conv[3], - 'media_title': conv[4] or "Unknown Media" - } - for conv in conversations - ] - except sqlite3.Error as e: - logging.error(f"Error searching chat conversations: {e}") - return [] - - -def update_chat_message(message_id: int, new_message: str) -> None: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - UPDATE ChatMessages - SET message = ?, timestamp = CURRENT_TIMESTAMP - WHERE id = ? - ''', (new_message, message_id)) - conn.commit() - except sqlite3.Error as e: - logging.error(f"Error updating chat message: {e}") - raise DatabaseError(f"Error updating chat message: {e}") - - -def delete_chat_message(message_id: int) -> None: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute('DELETE FROM ChatMessages WHERE id = ?', (message_id,)) - conn.commit() - except sqlite3.Error as e: - logging.error(f"Error deleting chat message: {e}") - raise DatabaseError(f"Error deleting chat message: {e}") - - -def save_chat_history_to_database(chatbot, conversation_id, media_id, media_name, conversation_name): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - # If conversation_id is None, create a new conversation - if conversation_id is None: - cursor.execute(''' - INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (media_id, media_name, conversation_name)) - conversation_id = cursor.lastrowid - else: - # If conversation exists, update the media_name - cursor.execute(''' - UPDATE ChatConversations - SET media_name = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (media_name, conversation_id)) - - # Save each message in the chatbot history - for i, (user_msg, ai_msg) in enumerate(chatbot): - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ''', (conversation_id, 'user', user_msg)) - - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ''', (conversation_id, 'ai', ai_msg)) - - # Update the conversation's updated_at timestamp - cursor.execute(''' - UPDATE ChatConversations - SET updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (conversation_id,)) - - conn.commit() - - return conversation_id - except Exception as e: - logging.error(f"Error saving chat history to database: {str(e)}") - raise - - -def get_conversation_name(conversation_id): - if conversation_id is None: - return None - - try: - with sqlite3.connect('media_summary.db') as conn: # Replace with your actual database name - cursor = conn.cursor() - - query = """ - SELECT conversation_name, media_name - FROM ChatConversations - WHERE id = ? 
- """ - - cursor.execute(query, (conversation_id,)) - result = cursor.fetchone() - - if result: - conversation_name, media_name = result - if conversation_name: - return conversation_name - elif media_name: - return f"{media_name}-chat" - - return None # Return None if no result found - except sqlite3.Error as e: - logging.error(f"Database error in get_conversation_name: {e}") - return None - except Exception as e: - logging.error(f"Unexpected error in get_conversation_name: {e}") - return None - -# -# End of Chat-related Functions -####################################################################################################################### - - ####################################################################################################################### # # Functions to Compare Transcripts @@ -2837,29 +2378,42 @@ def process_chunks(database, chunks: List[Dict], media_id: int, batch_size: int :param media_id: ID of the media these chunks belong to :param batch_size: Number of chunks to process in each batch """ + log_counter("process_chunks_attempt", labels={"media_id": media_id}) + start_time = time.time() total_chunks = len(chunks) processed_chunks = 0 - for i in range(0, total_chunks, batch_size): - batch = chunks[i:i + batch_size] - chunk_data = [ - (media_id, chunk['text'], chunk['start_index'], chunk['end_index']) - for chunk in batch - ] - - try: - database.execute_many( - "INSERT INTO MediaChunks (media_id, chunk_text, start_index, end_index) VALUES (?, ?, ?, ?)", - chunk_data - ) - processed_chunks += len(batch) - logging.info(f"Processed {processed_chunks}/{total_chunks} chunks for media_id {media_id}") - except Exception as e: - logging.error(f"Error inserting chunk batch for media_id {media_id}: {e}") - # Optionally, you could raise an exception here to stop processing - # raise + try: + for i in range(0, total_chunks, batch_size): + batch = chunks[i:i + batch_size] + chunk_data = [ + (media_id, chunk['text'], chunk['start_index'], chunk['end_index']) + for chunk in batch + ] - logging.info(f"Finished processing all {total_chunks} chunks for media_id {media_id}") + try: + database.execute_many( + "INSERT INTO MediaChunks (media_id, chunk_text, start_index, end_index) VALUES (?, ?, ?, ?)", + chunk_data + ) + processed_chunks += len(batch) + logging.info(f"Processed {processed_chunks}/{total_chunks} chunks for media_id {media_id}") + log_counter("process_chunks_batch_success", labels={"media_id": media_id}) + except Exception as e: + logging.error(f"Error inserting chunk batch for media_id {media_id}: {e}") + log_counter("process_chunks_batch_error", labels={"media_id": media_id, "error_type": type(e).__name__}) + # Optionally, you could raise an exception here to stop processing + # raise + + logging.info(f"Finished processing all {total_chunks} chunks for media_id {media_id}") + duration = time.time() - start_time + log_histogram("process_chunks_duration", duration, labels={"media_id": media_id}) + log_counter("process_chunks_success", labels={"media_id": media_id}) + except Exception as e: + duration = time.time() - start_time + log_histogram("process_chunks_duration", duration, labels={"media_id": media_id}) + log_counter("process_chunks_error", labels={"media_id": media_id, "error_type": type(e).__name__}) + logging.error(f"Error processing chunks for media_id {media_id}: {e}") # Usage example: @@ -2995,46 +2549,48 @@ def update_media_table(db): # # Workflow Functions +# Workflow Functions def save_workflow_chat_to_db(chat_history, workflow_name, 
conversation_id=None): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - if conversation_id is None: - # Create a new conversation - conversation_name = f"{workflow_name}_Workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - cursor.execute(''' - INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) - VALUES (NULL, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (workflow_name, conversation_name)) - conversation_id = cursor.lastrowid - else: - # Update existing conversation - cursor.execute(''' - UPDATE ChatConversations - SET updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (conversation_id,)) - - # Save messages - for user_msg, ai_msg in chat_history: - if user_msg: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, 'user', ?, CURRENT_TIMESTAMP) - ''', (conversation_id, user_msg)) - if ai_msg: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, 'ai', ?, CURRENT_TIMESTAMP) - ''', (conversation_id, ai_msg)) - - conn.commit() - - return conversation_id, f"Chat saved successfully! Conversation ID: {conversation_id}" - except Exception as e: - logging.error(f"Error saving workflow chat to database: {str(e)}") - return None, f"Error saving chat to database: {str(e)}" + pass +# try: +# with db.get_connection() as conn: +# cursor = conn.cursor() +# +# if conversation_id is None: +# # Create a new conversation +# conversation_name = f"{workflow_name}_Workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}" +# cursor.execute(''' +# INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) +# VALUES (NULL, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +# ''', (workflow_name, conversation_name)) +# conversation_id = cursor.lastrowid +# else: +# # Update existing conversation +# cursor.execute(''' +# UPDATE ChatConversations +# SET updated_at = CURRENT_TIMESTAMP +# WHERE id = ? +# ''', (conversation_id,)) +# +# # Save messages +# for user_msg, ai_msg in chat_history: +# if user_msg: +# cursor.execute(''' +# INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) +# VALUES (?, 'user', ?, CURRENT_TIMESTAMP) +# ''', (conversation_id, user_msg)) +# if ai_msg: +# cursor.execute(''' +# INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) +# VALUES (?, 'ai', ?, CURRENT_TIMESTAMP) +# ''', (conversation_id, ai_msg)) +# +# conn.commit() +# +# return conversation_id, f"Chat saved successfully! 
Conversation ID: {conversation_id}" +# except Exception as e: +# logging.error(f"Error saving workflow chat to database: {str(e)}") +# return None, f"Error saving chat to database: {str(e)}" def get_workflow_chat(conversation_id): diff --git a/App_Function_Libraries/Gradio_Related.py b/App_Function_Libraries/Gradio_Related.py index c2911040a..4f4e217d0 100644 --- a/App_Function_Libraries/Gradio_Related.py +++ b/App_Function_Libraries/Gradio_Related.py @@ -9,39 +9,43 @@ import logging import os import webbrowser - # # Import 3rd-Party Libraries import gradio as gr # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_db_config -from App_Function_Libraries.Gradio_UI.Anki_Validation_tab import create_anki_validation_tab +from App_Function_Libraries.DB.DB_Manager import get_db_config, backup_dir +from App_Function_Libraries.DB.RAG_QA_Chat_DB import create_tables +from App_Function_Libraries.Gradio_UI.Anki_tab import create_anki_validation_tab, create_anki_generator_tab from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab +from App_Function_Libraries.Gradio_UI.Backup_RAG_Notes_Character_Chat_tab import create_database_management_interface from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab from App_Function_Libraries.Gradio_UI.Character_Chat_tab import create_character_card_interaction_tab, create_character_chat_mgmt_tab, create_custom_character_card_tab, \ create_character_card_validation_tab, create_export_characters_tab from App_Function_Libraries.Gradio_UI.Character_interaction_tab import create_narrator_controlled_conversation_tab, \ create_multiple_character_chat_tab -from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_management_tab, \ - create_chat_interface_four, create_chat_interface_multi_api, create_chat_interface_stacked, create_chat_interface +from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_interface_four, create_chat_interface_multi_api, \ + create_chat_interface_stacked, create_chat_interface from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab -from App_Function_Libraries.Gradio_UI.Export_Functionality import create_export_tab -from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \ - create_restore_backup_tab +from App_Function_Libraries.Gradio_UI.Export_Functionality import create_rag_export_tab, create_export_tabs +#from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \ +# create_restore_backup_tab from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \ - create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab + create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab, \ + create_conversation_import_tab from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \ - create_delete_keyword_tab, create_export_keywords_tab + create_delete_keyword_tab, create_export_keywords_tab, create_rag_qa_keywords_tab, create_character_keywords_tab, \ + create_meta_keywords_tab, create_prompt_keywords_tab from 
App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab #from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \ create_media_edit_and_clone_tab, create_media_edit_tab from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab +from App_Function_Libraries.Gradio_UI.Mind_Map_tab import create_mindmap_tab from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, create_pdf_ingestion_test_tab from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab @@ -62,16 +66,19 @@ from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab from App_Function_Libraries.Gradio_UI.View_tab import create_manage_items_tab from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab -from App_Function_Libraries.Gradio_UI.Chat_Workflows import chat_workflows_tab -from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_prompt_view_tab, \ - create_view_all_mediadb_with_versions_tab, create_viewing_mediadb_tab, create_view_all_rag_notes_tab, \ - create_viewing_ragdb_tab, create_mediadb_keyword_search_tab, create_ragdb_keyword_items_tab +from App_Function_Libraries.Gradio_UI.Workflows_tab import chat_workflows_tab +from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_view_all_mediadb_with_versions_tab, \ + create_viewing_mediadb_tab, create_view_all_rag_notes_tab, create_viewing_ragdb_tab, \ + create_mediadb_keyword_search_tab, create_ragdb_keyword_items_tab +from App_Function_Libraries.Gradio_UI.Prompts_tab import create_prompt_view_tab, create_prompts_export_tab # # Gradio UI Imports from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab from App_Function_Libraries.Gradio_UI.XML_Ingestion_Tab import create_xml_import_tab #from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab +from App_Function_Libraries.Utils.Utils import load_and_log_configs + # ####################################################################################################################### # Function Definitions @@ -235,6 +242,147 @@ # all_prompts2 = prompts_category_1 + prompts_category_2 + +####################################################################################################################### +# +# Migration Script +import sqlite3 +import uuid +import logging +import os +from datetime import datetime +import shutil + +# def migrate_media_db_to_rag_chat_db(media_db_path, rag_chat_db_path): +# # Check if migration is needed +# if not os.path.exists(media_db_path): +# logging.info("Media DB does not exist. No migration needed.") +# return +# +# # Optional: Check if migration has already been completed +# migration_flag = os.path.join(os.path.dirname(rag_chat_db_path), 'migration_completed.flag') +# if os.path.exists(migration_flag): +# logging.info("Migration already completed. 
Skipping migration.") +# return +# +# # Backup databases +# backup_database(media_db_path) +# backup_database(rag_chat_db_path) +# +# # Connect to both databases +# try: +# media_conn = sqlite3.connect(media_db_path) +# rag_conn = sqlite3.connect(rag_chat_db_path) +# +# # Enable foreign key support +# media_conn.execute('PRAGMA foreign_keys = ON;') +# rag_conn.execute('PRAGMA foreign_keys = ON;') +# +# media_cursor = media_conn.cursor() +# rag_cursor = rag_conn.cursor() +# +# # Begin transaction +# rag_conn.execute('BEGIN TRANSACTION;') +# +# # Extract conversations from media DB +# media_cursor.execute(''' +# SELECT id, media_id, media_name, conversation_name, created_at, updated_at +# FROM ChatConversations +# ''') +# conversations = media_cursor.fetchall() +# +# for conv in conversations: +# old_conv_id, media_id, media_name, conversation_name, created_at, updated_at = conv +# +# # Convert timestamps if necessary +# created_at = parse_timestamp(created_at) +# updated_at = parse_timestamp(updated_at) +# +# # Generate a new conversation_id +# conversation_id = str(uuid.uuid4()) +# title = conversation_name or (f"{media_name}-chat" if media_name else "Untitled Conversation") +# +# # Insert into conversation_metadata +# rag_cursor.execute(''' +# INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title, media_id) +# VALUES (?, ?, ?, ?, ?) +# ''', (conversation_id, created_at, updated_at, title, media_id)) +# +# # Extract messages from media DB +# media_cursor.execute(''' +# SELECT sender, message, timestamp +# FROM ChatMessages +# WHERE conversation_id = ? +# ORDER BY timestamp ASC +# ''', (old_conv_id,)) +# messages = media_cursor.fetchall() +# +# for msg in messages: +# sender, content, timestamp = msg +# +# # Convert timestamp if necessary +# timestamp = parse_timestamp(timestamp) +# +# role = sender # Assuming 'sender' is 'user' or 'ai' +# +# # Insert message into rag_qa_chats +# rag_cursor.execute(''' +# INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content) +# VALUES (?, ?, ?, ?) +# ''', (conversation_id, timestamp, role, content)) +# +# # Commit transaction +# rag_conn.commit() +# logging.info("Migration completed successfully.") +# +# # Mark migration as complete +# with open(migration_flag, 'w') as f: +# f.write('Migration completed on ' + datetime.now().isoformat()) +# +# except Exception as e: +# # Rollback transaction in case of error +# rag_conn.rollback() +# logging.error(f"Error during migration: {e}") +# raise +# finally: +# media_conn.close() +# rag_conn.close() + +def backup_database(db_path): + backup_path = db_path + '.backup' + if not os.path.exists(backup_path): + shutil.copyfile(db_path, backup_path) + logging.info(f"Database backed up to {backup_path}") + else: + logging.info(f"Backup already exists at {backup_path}") + +def parse_timestamp(timestamp_value): + """ + Parses the timestamp from the old database and converts it to a standard format. + Adjust this function based on the actual format of your timestamps. 
+ """ + try: + # Attempt to parse ISO format + return datetime.fromisoformat(timestamp_value).isoformat() + except ValueError: + # Handle other timestamp formats if necessary + # For example, if timestamps are in Unix epoch format + try: + timestamp_float = float(timestamp_value) + return datetime.fromtimestamp(timestamp_float).isoformat() + except ValueError: + # Default to current time if parsing fails + logging.warning(f"Unable to parse timestamp '{timestamp_value}', using current time.") + return datetime.now().isoformat() + +# +# End of Migration Script +####################################################################################################################### + + +####################################################################################################################### +# +# Launch UI Function def launch_ui(share_public=None, server_mode=False): webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark') share=share_public @@ -257,6 +405,19 @@ def launch_ui(share_public=None, server_mode=False): } """ + config = load_and_log_configs() + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "chatDB.db") + rag_chat_db_path = os.path.join(os.path.dirname(media_db_path), "rag_qa.db") + # Initialize the RAG Chat DB (create tables and update schema) + create_tables() + + # Migrate data from the media DB to the RAG Chat DB + #migrate_media_db_to_rag_chat_db(media_db_path, rag_chat_db_path) + + with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface: gr.HTML( """ @@ -290,10 +451,6 @@ def launch_ui(share_public=None, server_mode=False): create_arxiv_tab() create_semantic_scholar_tab() - with gr.TabItem("Text Search", id="text search", visible=True): - create_search_tab() - create_search_summaries_tab() - with gr.TabItem("RAG Chat/Search", id="RAG Chat Notes group", visible=True): create_rag_tab() create_rag_qa_chat_tab() @@ -305,8 +462,6 @@ def launch_ui(share_public=None, server_mode=False): create_chat_interface_stacked() create_chat_interface_multi_api() create_chat_interface_four() - create_chat_with_llamafile_tab() - create_chat_management_tab() chat_workflows_tab() with gr.TabItem("Character Chat", id="character chat group", visible=True): @@ -318,51 +473,56 @@ def launch_ui(share_public=None, server_mode=False): create_narrator_controlled_conversation_tab() create_export_characters_tab() - with gr.TabItem("View DB Items", id="view db items group", visible=True): + with gr.TabItem("Writing Tools", id="writing_tools group", visible=True): + from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab + create_document_feedback_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab + create_grammar_style_check_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab + create_tone_adjustment_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab + create_creative_writing_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab + create_mikupad_tab() + + with gr.TabItem("Search/View DB Items", id="view db items group", visible=True): + create_search_tab() + create_search_summaries_tab() create_view_all_mediadb_with_versions_tab() create_viewing_mediadb_tab() create_mediadb_keyword_search_tab() create_view_all_rag_notes_tab() create_viewing_ragdb_tab() 
create_ragdb_keyword_items_tab() - create_prompt_view_tab() with gr.TabItem("Prompts", id='view prompts group', visible=True): - create_prompt_view_tab() - create_prompt_search_tab() - create_prompt_edit_tab() - create_prompt_clone_tab() - create_prompt_suggestion_tab() - - with gr.TabItem("Manage / Edit Existing Items", id="manage group", visible=True): + with gr.Tabs(): + create_prompt_view_tab() + create_prompt_search_tab() + create_prompt_edit_tab() + create_prompt_clone_tab() + create_prompt_suggestion_tab() + create_prompts_export_tab() + + with gr.TabItem("Manage Media DB Items", id="manage group", visible=True): create_media_edit_tab() create_manage_items_tab() create_media_edit_and_clone_tab() - # FIXME - #create_compare_transcripts_tab() with gr.TabItem("Embeddings Management", id="embeddings group", visible=True): create_embeddings_tab() create_view_embeddings_tab() create_purge_embeddings_tab() - with gr.TabItem("Writing Tools", id="writing_tools group", visible=True): - from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab - create_document_feedback_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab - create_grammar_style_check_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab - create_tone_adjustment_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab - create_creative_writing_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab - create_mikupad_tab() - with gr.TabItem("Keywords", id="keywords group", visible=True): create_view_keywords_tab() create_add_keyword_tab() create_delete_keyword_tab() create_export_keywords_tab() + create_character_keywords_tab() + create_rag_qa_keywords_tab() + create_meta_keywords_tab() + create_prompt_keywords_tab() with gr.TabItem("Import", id="import group", visible=True): create_import_item_tab() @@ -371,23 +531,38 @@ def launch_ui(share_public=None, server_mode=False): create_import_multiple_prompts_tab() create_mediawiki_import_tab() create_mediawiki_config_tab() + create_conversation_import_tab() with gr.TabItem("Export", id="export group", visible=True): - create_export_tab() - - with gr.TabItem("Backup Management", id="backup group", visible=True): - create_backup_tab() - create_view_backups_tab() - create_restore_backup_tab() + create_export_tabs() + + + with gr.TabItem("Database Management", id="database_management_group", visible=True): + create_database_management_interface( + media_db_config={ + 'db_path': media_db_path, + 'backup_dir': backup_dir + }, + rag_db_config={ + 'db_path': rag_chat_db_path, + 'backup_dir': backup_dir + }, + char_db_config={ + 'db_path': character_chat_db_path, + 'backup_dir': backup_dir + } + ) with gr.TabItem("Utilities", id="util group", visible=True): - # FIXME - #create_anki_generation_tab() - create_anki_validation_tab() + create_mindmap_tab() create_utilities_yt_video_tab() create_utilities_yt_audio_tab() create_utilities_yt_timestamp_tab() + with gr.TabItem("Anki Deck Creation/Validation", id="anki group", visible=True): + create_anki_generator_tab() + create_anki_validation_tab() + with gr.TabItem("Local LLM", id="local llm group", visible=True): create_chat_with_llamafile_tab() create_ollama_tab() diff --git a/App_Function_Libraries/Gradio_UI/Anki_Validation_tab.py b/App_Function_Libraries/Gradio_UI/Anki_Validation_tab.py deleted file mode 100644 index 7560d97b5..000000000 --- 
a/App_Function_Libraries/Gradio_UI/Anki_Validation_tab.py +++ /dev/null @@ -1,836 +0,0 @@ -# Anki_Validation_tab.py -# Description: Gradio functions for the Anki Validation tab -# -# Imports -from datetime import datetime -import base64 -import json -import logging -import os -from pathlib import Path -import shutil -import sqlite3 -import tempfile -from typing import Dict, Any, Optional, Tuple, List -import zipfile -# -# External Imports -import gradio as gr -#from outlines import models, prompts -# -# Local Imports -from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper -from App_Function_Libraries.Third_Party.Anki import sanitize_html, generate_card_choices, \ - export_cards, load_card_for_editing, validate_flashcards, handle_file_upload, \ - validate_for_ui, update_card_with_validation, update_card_choices, format_validation_result, enhanced_file_upload, \ - handle_validation -from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints -# -############################################################################################################ -# -# Functions: - -# def create_anki_generation_tab(): -# try: -# default_value = None -# if default_api_endpoint: -# if default_api_endpoint in global_api_endpoints: -# default_value = format_api_name(default_api_endpoint) -# else: -# logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") -# except Exception as e: -# logging.error(f"Error setting default API endpoint: {str(e)}") -# default_value = None -# with gr.TabItem("Anki Flashcard Generation", visible=True): -# gr.Markdown("# Anki Flashcard Generation") -# chat_history = gr.State([]) -# generated_cards_state = gr.State({}) -# -# # Add progress tracking -# generation_progress = gr.Progress() -# status_message = gr.Status() -# -# with gr.Row(): -# # Left Column: Generation Controls -# with gr.Column(scale=1): -# gr.Markdown("## Content Input") -# source_text = gr.TextArea( -# label="Source Text or Topic", -# placeholder="Enter the text or topic you want to create flashcards from...", -# lines=5 -# ) -# -# # API Configuration -# api_endpoint = gr.Dropdown( -# choices=["None"] + [format_api_name(api) for api in global_api_endpoints], -# value=default_value, -# label="API for Card Generation" -# ) -# api_key = gr.Textbox(label="API Key (if required)", type="password") -# -# with gr.Accordion("Generation Settings", open=True): -# num_cards = gr.Slider( -# minimum=1, -# maximum=20, -# value=5, -# step=1, -# label="Number of Cards" -# ) -# -# card_types = gr.CheckboxGroup( -# choices=["basic", "cloze", "reverse"], -# value=["basic"], -# label="Card Types to Generate" -# ) -# -# difficulty_level = gr.Radio( -# choices=["beginner", "intermediate", "advanced"], -# value="intermediate", -# label="Difficulty Level" -# ) -# -# subject_area = gr.Dropdown( -# choices=[ -# "general", -# "language_learning", -# "science", -# "mathematics", -# "history", -# "geography", -# "computer_science", -# "custom" -# ], -# value="general", -# label="Subject Area" -# ) -# -# custom_subject = gr.Textbox( -# label="Custom Subject", -# visible=False, -# placeholder="Enter custom subject..." 
-# ) -# -# with gr.Accordion("Advanced Options", open=False): -# temperature = gr.Slider( -# label="Temperature", -# minimum=0.00, -# maximum=1.0, -# step=0.05, -# value=0.7 -# ) -# -# max_retries = gr.Slider( -# label="Max Retries on Error", -# minimum=1, -# maximum=5, -# step=1, -# value=3 -# ) -# -# include_examples = gr.Checkbox( -# label="Include example usage", -# value=True -# ) -# -# include_mnemonics = gr.Checkbox( -# label="Generate mnemonics", -# value=True -# ) -# -# include_hints = gr.Checkbox( -# label="Include hints", -# value=True -# ) -# -# tag_style = gr.Radio( -# choices=["broad", "specific", "hierarchical"], -# value="specific", -# label="Tag Style" -# ) -# -# system_prompt = gr.Textbox( -# label="System Prompt", -# value="You are an expert at creating effective Anki flashcards.", -# lines=2 -# ) -# -# generate_button = gr.Button("Generate Flashcards") -# regenerate_button = gr.Button("Regenerate", visible=False) -# error_log = gr.TextArea( -# label="Error Log", -# visible=False, -# lines=3 -# ) -# -# # Right Column: Chat Interface and Preview -# with gr.Column(scale=1): -# gr.Markdown("## Interactive Card Generation") -# chatbot = gr.Chatbot(height=400, elem_classes="chatbot-container") -# -# with gr.Row(): -# msg = gr.Textbox( -# label="Chat to refine cards", -# placeholder="Ask questions or request modifications..." -# ) -# submit_chat = gr.Button("Submit") -# -# gr.Markdown("## Generated Cards Preview") -# generated_cards = gr.JSON(label="Generated Flashcards") -# -# with gr.Row(): -# edit_generated = gr.Button("Edit in Validator") -# save_generated = gr.Button("Save to File") -# clear_chat = gr.Button("Clear Chat") -# -# generation_status = gr.Markdown("") -# download_file = gr.File(label="Download Cards", visible=False) -# -# # Helper Functions and Classes -# class AnkiCardGenerator: -# def __init__(self): -# self.schema = { -# "type": "object", -# "properties": { -# "cards": { -# "type": "array", -# "items": { -# "type": "object", -# "properties": { -# "id": {"type": "string"}, -# "type": {"type": "string", "enum": ["basic", "cloze", "reverse"]}, -# "front": {"type": "string"}, -# "back": {"type": "string"}, -# "tags": { -# "type": "array", -# "items": {"type": "string"} -# }, -# "note": {"type": "string"} -# }, -# "required": ["id", "type", "front", "back", "tags"] -# } -# } -# }, -# "required": ["cards"] -# } -# -# self.template = prompts.TextTemplate(""" -# Generate {num_cards} Anki flashcards about: {text} -# -# Requirements: -# - Difficulty: {difficulty} -# - Subject: {subject} -# - Card Types: {card_types} -# - Include Examples: {include_examples} -# - Include Mnemonics: {include_mnemonics} -# - Include Hints: {include_hints} -# - Tag Style: {tag_style} -# -# Each card must have: -# 1. Unique ID starting with CARD_ -# 2. Type (one of: basic, cloze, reverse) -# 3. Clear question/prompt on front -# 4. Comprehensive answer on back -# 5. Relevant tags including subject and difficulty -# 6. Optional note with study tips or mnemonics -# -# For cloze deletions, use the format {{c1::text to be hidden}}. 
-# -# Ensure each card: -# - Focuses on a single concept -# - Is clear and unambiguous -# - Uses appropriate formatting -# - Has relevant tags -# - Includes requested additional information -# """) -# -# async def generate_with_progress( -# self, -# text: str, -# config: Dict[str, Any], -# progress: gr.Progress -# ) -> GenerationResult: -# try: -# # Initialize progress -# progress(0, desc="Initializing generation...") -# -# # Configure model -# model = models.Model(config["api_endpoint"]) -# -# # Generate with schema validation -# progress(0.3, desc="Generating cards...") -# response = await model.generate( -# self.template, -# schema=self.schema, -# text=text, -# **config -# ) -# -# # Validate response -# progress(0.6, desc="Validating generated cards...") -# validated_cards = self.validate_cards(response) -# -# # Final processing -# progress(0.9, desc="Finalizing...") -# time.sleep(0.5) # Brief pause for UI feedback -# return GenerationResult( -# cards=validated_cards, -# error=None, -# status="Generation completed successfully!", -# progress=1.0 -# ) -# -# except Exception as e: -# logging.error(f"Card generation error: {str(e)}") -# return GenerationResult( -# cards=None, -# error=str(e), -# status=f"Error: {str(e)}", -# progress=1.0 -# ) -# -# def validate_cards(self, cards: Dict[str, Any]) -> Dict[str, Any]: -# """Validate and clean generated cards""" -# if not isinstance(cards, dict) or "cards" not in cards: -# raise ValueError("Invalid card format") -# -# seen_ids = set() -# cleaned_cards = [] -# -# for card in cards["cards"]: -# # Check ID uniqueness -# if card["id"] in seen_ids: -# card["id"] = f"{card['id']}_{len(seen_ids)}" -# seen_ids.add(card["id"]) -# -# # Validate card type -# if card["type"] not in ["basic", "cloze", "reverse"]: -# raise ValueError(f"Invalid card type: {card['type']}") -# -# # Check content -# if not card["front"].strip() or not card["back"].strip(): -# raise ValueError("Empty card content") -# -# # Validate cloze format -# if card["type"] == "cloze" and "{{c1::" not in card["front"]: -# raise ValueError("Invalid cloze format") -# -# # Clean and standardize tags -# if not isinstance(card["tags"], list): -# card["tags"] = [str(card["tags"])] -# card["tags"] = [tag.strip().lower() for tag in card["tags"] if tag.strip()] -# -# cleaned_cards.append(card) -# -# return {"cards": cleaned_cards} -# -# # Initialize generator -# generator = AnkiCardGenerator() -# -# async def generate_flashcards(*args): -# text, num_cards, card_types, difficulty, subject, custom_subject, \ -# include_examples, include_mnemonics, include_hints, tag_style, \ -# temperature, api_endpoint, api_key, system_prompt, max_retries = args -# -# actual_subject = custom_subject if subject == "custom" else subject -# -# config = { -# "num_cards": num_cards, -# "difficulty": difficulty, -# "subject": actual_subject, -# "card_types": card_types, -# "include_examples": include_examples, -# "include_mnemonics": include_mnemonics, -# "include_hints": include_hints, -# "tag_style": tag_style, -# "temperature": temperature, -# "api_endpoint": api_endpoint, -# "api_key": api_key, -# "system_prompt": system_prompt -# } -# -# errors = [] -# retry_count = 0 -# -# while retry_count < max_retries: -# try: -# result = await generator.generate_with_progress(text, config, generation_progress) -# -# if result.error: -# errors.append(f"Attempt {retry_count + 1}: {result.error}") -# retry_count += 1 -# await asyncio.sleep(1) -# continue -# -# return ( -# result.cards, -# gr.update(visible=True), -# result.status, 
-# gr.update(visible=False), -# [[None, "Cards generated! You can now modify them through chat."]] -# ) -# -# except Exception as e: -# errors.append(f"Attempt {retry_count + 1}: {str(e)}") -# retry_count += 1 -# await asyncio.sleep(1) -# -# error_log = "\n".join(errors) -# return ( -# None, -# gr.update(visible=False), -# "Failed to generate cards after all retries", -# gr.update(value=error_log, visible=True), -# [[None, "Failed to generate cards. Please check the error log."]] -# ) -# -# def save_generated_cards(cards): -# if not cards: -# return "No cards to save", None -# -# try: -# cards_json = json.dumps(cards, indent=2) -# current_time = datetime.now().strftime("%Y%m%d_%H%M%S") -# filename = f"anki_cards_{current_time}.json" -# -# return ( -# "Cards saved successfully!", -# (filename, cards_json, "application/json") -# ) -# except Exception as e: -# logging.error(f"Error saving cards: {e}") -# return f"Error saving cards: {str(e)}", None -# -# def clear_chat_history(): -# return [], [], "Chat cleared" -# -# def toggle_custom_subject(choice): -# return gr.update(visible=choice == "custom") -# -# def send_to_validator(cards): -# if not cards: -# return "No cards to validate" -# try: -# # Here you would integrate with your validation tab -# validated_cards = generator.validate_cards(cards) -# return "Cards validated and sent to validator" -# except Exception as e: -# logging.error(f"Validation error: {e}") -# return f"Validation error: {str(e)}" -# -# # Register callbacks -# subject_area.change( -# fn=toggle_custom_subject, -# inputs=subject_area, -# outputs=custom_subject -# ) -# -# generate_button.click( -# fn=generate_flashcards, -# inputs=[ -# source_text, num_cards, card_types, difficulty_level, -# subject_area, custom_subject, include_examples, -# include_mnemonics, include_hints, tag_style, -# temperature, api_endpoint, api_key, system_prompt, -# max_retries -# ], -# outputs=[ -# generated_cards, -# regenerate_button, -# generation_status, -# error_log, -# chatbot -# ] -# ) -# -# regenerate_button.click( -# fn=generate_flashcards, -# inputs=[ -# source_text, num_cards, card_types, difficulty_level, -# subject_area, custom_subject, include_examples, -# include_mnemonics, include_hints, tag_style, -# temperature, api_endpoint, api_key, system_prompt, -# max_retries -# ], -# outputs=[ -# generated_cards, -# regenerate_button, -# generation_status, -# error_log, -# chatbot -# ] -# ) -# -# clear_chat.click( -# fn=clear_chat_history, -# outputs=[chatbot, chat_history, generation_status] -# ) -# -# edit_generated.click( -# fn=send_to_validator, -# inputs=generated_cards, -# outputs=generation_status -# ) -# -# save_generated.click( -# fn=save_generated_cards, -# inputs=generated_cards, -# outputs=[generation_status, download_file] -# ) -# -# return ( -# source_text, num_cards, card_types, difficulty_level, -# subject_area, custom_subject, include_examples, -# include_mnemonics, include_hints, tag_style, -# api_endpoint, api_key, temperature, system_prompt, -# generate_button, regenerate_button, generated_cards, -# edit_generated, save_generated, clear_chat, -# generation_status, chatbot, msg, submit_chat, -# chat_history, generated_cards_state, download_file, -# error_log, max_retries -# ) - -def create_anki_validation_tab(): - with gr.TabItem("Anki Flashcard Validation", visible=True): - gr.Markdown("# Anki Flashcard Validation and Editor") - - # State variables for internal tracking - current_card_data = gr.State({}) - preview_update_flag = gr.State(False) - - with gr.Row(): - # 
Left Column: Input and Validation - with gr.Column(scale=1): - gr.Markdown("## Import or Create Flashcards") - - input_type = gr.Radio( - choices=["JSON", "APKG"], - label="Input Type", - value="JSON" - ) - - with gr.Group() as json_input_group: - flashcard_input = gr.TextArea( - label="Enter Flashcards (JSON format)", - placeholder='''{ - "cards": [ - { - "id": "CARD_001", - "type": "basic", - "front": "What is the capital of France?", - "back": "Paris", - "tags": ["geography", "europe"], - "note": "Remember: City of Light" - } - ] -}''', - lines=10 - ) - - import_json = gr.File( - label="Or Import JSON File", - file_types=[".json"] - ) - - with gr.Group(visible=False) as apkg_input_group: - import_apkg = gr.File( - label="Import APKG File", - file_types=[".apkg"] - ) - deck_info = gr.JSON( - label="Deck Information", - visible=False - ) - - validate_button = gr.Button("Validate Flashcards") - - # Right Column: Validation Results and Editor - with gr.Column(scale=1): - gr.Markdown("## Validation Results") - validation_status = gr.Markdown("") - - with gr.Accordion("Validation Rules", open=False): - gr.Markdown(""" - ### Required Fields: - - Unique ID - - Card Type (basic, cloze, reverse) - - Front content - - Back content - - At least one tag - - ### Content Rules: - - No empty fields - - Front side should be a clear question/prompt - - Back side should contain complete answer - - Cloze deletions must have valid syntax - - No duplicate IDs - - ### Image Rules: - - Valid image tags - - Supported formats (JPG, PNG, GIF) - - Base64 encoded or valid URL - - ### APKG-specific Rules: - - Valid SQLite database structure - - Media files properly referenced - - Note types match Anki standards - - Card templates are well-formed - """) - - with gr.Row(): - # Card Editor - gr.Markdown("## Card Editor") - with gr.Row(): - with gr.Column(scale=1): - with gr.Accordion("Edit Individual Cards", open=True): - card_selector = gr.Dropdown( - label="Select Card to Edit", - choices=[], - interactive=True - ) - - card_type = gr.Radio( - choices=["basic", "cloze", "reverse"], - label="Card Type", - value="basic" - ) - - # Front content with preview - with gr.Group(): - gr.Markdown("### Front Content") - front_content = gr.TextArea( - label="Content (HTML supported)", - lines=3 - ) - front_preview = gr.HTML( - label="Preview" - ) - - # Back content with preview - with gr.Group(): - gr.Markdown("### Back Content") - back_content = gr.TextArea( - label="Content (HTML supported)", - lines=3 - ) - back_preview = gr.HTML( - label="Preview" - ) - - tags_input = gr.TextArea( - label="Tags (comma-separated)", - lines=1 - ) - - notes_input = gr.TextArea( - label="Additional Notes", - lines=2 - ) - - with gr.Row(): - update_card_button = gr.Button("Update Card") - delete_card_button = gr.Button("Delete Card", variant="stop") - - with gr.Row(): - with gr.Column(scale=1): - # Export Options - gr.Markdown("## Export Options") - export_format = gr.Radio( - choices=["Anki CSV", "JSON", "Plain Text"], - label="Export Format", - value="Anki CSV" - ) - export_button = gr.Button("Export Valid Cards") - export_file = gr.File(label="Download Validated Cards") - export_status = gr.Markdown("") - with gr.Column(scale=1): - gr.Markdown("## Export Instructions") - gr.Markdown(""" - ### Anki CSV Format: - - Front, Back, Tags, Type, Note - - Use for importing into Anki - - Images preserved as HTML - - ### JSON Format: - - JSON array of cards - - Images as base64 or URLs - - Use for custom processing - - ### Plain Text Format: - - Question 
and Answer pairs - - Images represented as [IMG] placeholder - - Use for manual review - """) - - def update_preview(content): - """Update preview with sanitized content.""" - if not content: - return "" - return sanitize_html(content) - - # Event handlers - def validation_chain(content: str) -> Tuple[str, List[str]]: - """Combined validation and card choice update.""" - validation_message = validate_for_ui(content) - card_choices = update_card_choices(content) - return validation_message, card_choices - - def delete_card(card_selection, current_content): - """Delete selected card and return updated content.""" - if not card_selection or not current_content: - return current_content, "No card selected", [] - - try: - data = json.loads(current_content) - selected_id = card_selection.split(" - ")[0] - - data['cards'] = [card for card in data['cards'] if card['id'] != selected_id] - new_content = json.dumps(data, indent=2) - - return ( - new_content, - "Card deleted successfully!", - generate_card_choices(new_content) - ) - - except Exception as e: - return current_content, f"Error deleting card: {str(e)}", [] - - def process_validation_result(is_valid, message): - """Process validation result into a formatted markdown string.""" - if is_valid: - return f"✅ {message}" - else: - return f"❌ {message}" - - # Register event handlers - input_type.change( - fn=lambda t: ( - gr.update(visible=t == "JSON"), - gr.update(visible=t == "APKG"), - gr.update(visible=t == "APKG") - ), - inputs=[input_type], - outputs=[json_input_group, apkg_input_group, deck_info] - ) - - # File upload handlers - import_json.upload( - fn=handle_file_upload, - inputs=[import_json, input_type], - outputs=[ - flashcard_input, - deck_info, - validation_status, - card_selector - ] - ) - - import_apkg.upload( - fn=enhanced_file_upload, - inputs=[import_apkg, input_type], - outputs=[ - flashcard_input, - deck_info, - validation_status, - card_selector - ] - ) - - # Validation handler - validate_button.click( - fn=lambda content, input_format: ( - handle_validation(content, input_format), - generate_card_choices(content) if content else [] - ), - inputs=[flashcard_input, input_type], - outputs=[validation_status, card_selector] - ) - - # Card editing handlers - # Card selector change event - card_selector.change( - fn=load_card_for_editing, - inputs=[card_selector, flashcard_input], - outputs=[ - card_type, - front_content, - back_content, - tags_input, - notes_input, - front_preview, - back_preview - ] - ) - - # Live preview updates - front_content.change( - fn=update_preview, - inputs=[front_content], - outputs=[front_preview] - ) - - back_content.change( - fn=update_preview, - inputs=[back_content], - outputs=[back_preview] - ) - - # Card update handler - update_card_button.click( - fn=update_card_with_validation, - inputs=[ - flashcard_input, - card_selector, - card_type, - front_content, - back_content, - tags_input, - notes_input - ], - outputs=[ - flashcard_input, - validation_status, - card_selector - ] - ) - - # Delete card handler - delete_card_button.click( - fn=delete_card, - inputs=[card_selector, flashcard_input], - outputs=[flashcard_input, validation_status, card_selector] - ) - - # Export handler - export_button.click( - fn=export_cards, - inputs=[flashcard_input, export_format], - outputs=[export_status, export_file] - ) - - return ( - flashcard_input, - import_json, - import_apkg, - validate_button, - validation_status, - card_selector, - card_type, - front_content, - back_content, - front_preview, - 
back_preview, - tags_input, - notes_input, - update_card_button, - delete_card_button, - export_format, - export_button, - export_file, - export_status, - deck_info - ) - -# -# End of Anki_Validation_tab.py -############################################################################################################ diff --git a/App_Function_Libraries/Gradio_UI/Anki_tab.py b/App_Function_Libraries/Gradio_UI/Anki_tab.py new file mode 100644 index 000000000..4f261f9b5 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Anki_tab.py @@ -0,0 +1,921 @@ +# Anki_Validation_tab.py +# Description: Gradio functions for the Anki Validation tab +# +# Imports +import json +import logging +import os +import tempfile +from typing import Optional, Tuple, List, Dict +# +# External Imports +import genanki +import gradio as gr +# +# Local Imports +from App_Function_Libraries.Chat.Chat_Functions import approximate_token_count, update_chat_content, save_chat_history, \ + save_chat_history_to_db_wrapper +from App_Function_Libraries.DB.DB_Manager import list_prompts +from App_Function_Libraries.Gradio_UI.Chat_ui import update_dropdown_multiple, chat_wrapper, update_selected_parts, \ + search_conversations, regenerate_last_message, load_conversation, debug_output +from App_Function_Libraries.Third_Party.Anki import sanitize_html, generate_card_choices, \ + export_cards, load_card_for_editing, handle_file_upload, \ + validate_for_ui, update_card_with_validation, update_card_choices, enhanced_file_upload, \ + handle_validation +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +# +############################################################################################################ +# +# Functions: + +def create_anki_validation_tab(): + with gr.TabItem("Anki Flashcard Validation", visible=True): + gr.Markdown("# Anki Flashcard Validation and Editor") + + # State variables for internal tracking + current_card_data = gr.State({}) + preview_update_flag = gr.State(False) + + with gr.Row(): + # Left Column: Input and Validation + with gr.Column(scale=1): + gr.Markdown("## Import or Create Flashcards") + + input_type = gr.Radio( + choices=["JSON", "APKG"], + label="Input Type", + value="JSON" + ) + + with gr.Group() as json_input_group: + flashcard_input = gr.TextArea( + label="Enter Flashcards (JSON format)", + placeholder='''{ + "cards": [ + { + "id": "CARD_001", + "type": "basic", + "front": "What is the capital of France?", + "back": "Paris", + "tags": ["geography", "europe"], + "note": "Remember: City of Light" + } + ] +}''', + lines=10 + ) + + import_json = gr.File( + label="Or Import JSON File", + file_types=[".json"] + ) + + with gr.Group(visible=False) as apkg_input_group: + import_apkg = gr.File( + label="Import APKG File", + file_types=[".apkg"] + ) + deck_info = gr.JSON( + label="Deck Information", + visible=False + ) + + validate_button = gr.Button("Validate Flashcards") + + # Right Column: Validation Results and Editor + with gr.Column(scale=1): + gr.Markdown("## Validation Results") + validation_status = gr.Markdown("") + + with gr.Accordion("Validation Rules", open=False): + gr.Markdown(""" + ### Required Fields: + - Unique ID + - Card Type (basic, cloze, reverse) + - Front content + - Back content + - At least one tag + + ### Content Rules: + - No empty fields + - Front side should be a clear question/prompt + - Back side should contain complete answer + - Cloze deletions must have valid syntax + - No duplicate IDs + + ### Image Rules: + - 
Valid image tags + - Supported formats (JPG, PNG, GIF) + - Base64 encoded or valid URL + + ### APKG-specific Rules: + - Valid SQLite database structure + - Media files properly referenced + - Note types match Anki standards + - Card templates are well-formed + """) + + with gr.Row(): + # Card Editor + gr.Markdown("## Card Editor") + with gr.Row(): + with gr.Column(scale=1): + with gr.Accordion("Edit Individual Cards", open=True): + card_selector = gr.Dropdown( + label="Select Card to Edit", + choices=[], + interactive=True + ) + + card_type = gr.Radio( + choices=["basic", "cloze", "reverse"], + label="Card Type", + value="basic" + ) + + # Front content with preview + with gr.Group(): + gr.Markdown("### Front Content") + front_content = gr.TextArea( + label="Content (HTML supported)", + lines=3 + ) + front_preview = gr.HTML( + label="Preview" + ) + + # Back content with preview + with gr.Group(): + gr.Markdown("### Back Content") + back_content = gr.TextArea( + label="Content (HTML supported)", + lines=3 + ) + back_preview = gr.HTML( + label="Preview" + ) + + tags_input = gr.TextArea( + label="Tags (comma-separated)", + lines=1 + ) + + notes_input = gr.TextArea( + label="Additional Notes", + lines=2 + ) + + with gr.Row(): + update_card_button = gr.Button("Update Card") + delete_card_button = gr.Button("Delete Card", variant="stop") + + with gr.Row(): + with gr.Column(scale=1): + # Export Options + gr.Markdown("## Export Options") + export_format = gr.Radio( + choices=["Anki CSV", "JSON", "Plain Text"], + label="Export Format", + value="Anki CSV" + ) + export_button = gr.Button("Export Valid Cards") + export_file = gr.File(label="Download Validated Cards") + export_status = gr.Markdown("") + with gr.Column(scale=1): + gr.Markdown("## Export Instructions") + gr.Markdown(""" + ### Anki CSV Format: + - Front, Back, Tags, Type, Note + - Use for importing into Anki + - Images preserved as HTML + + ### JSON Format: + - JSON array of cards + - Images as base64 or URLs + - Use for custom processing + + ### Plain Text Format: + - Question and Answer pairs + - Images represented as [IMG] placeholder + - Use for manual review + """) + + def update_preview(content): + """Update preview with sanitized content.""" + if not content: + return "" + return sanitize_html(content) + + # Event handlers + def validation_chain(content: str) -> Tuple[str, List[str]]: + """Combined validation and card choice update.""" + validation_message = validate_for_ui(content) + card_choices = update_card_choices(content) + return validation_message, card_choices + + def delete_card(card_selection, current_content): + """Delete selected card and return updated content.""" + if not card_selection or not current_content: + return current_content, "No card selected", [] + + try: + data = json.loads(current_content) + selected_id = card_selection.split(" - ")[0] + + data['cards'] = [card for card in data['cards'] if card['id'] != selected_id] + new_content = json.dumps(data, indent=2) + + return ( + new_content, + "Card deleted successfully!", + generate_card_choices(new_content) + ) + + except Exception as e: + return current_content, f"Error deleting card: {str(e)}", [] + + def process_validation_result(is_valid, message): + """Process validation result into a formatted markdown string.""" + if is_valid: + return f"✅ {message}" + else: + return f"❌ {message}" + + # Register event handlers + input_type.change( + fn=lambda t: ( + gr.update(visible=t == "JSON"), + gr.update(visible=t == "APKG"), + gr.update(visible=t == "APKG") + 
), + inputs=[input_type], + outputs=[json_input_group, apkg_input_group, deck_info] + ) + + # File upload handlers + import_json.upload( + fn=handle_file_upload, + inputs=[import_json, input_type], + outputs=[ + flashcard_input, + deck_info, + validation_status, + card_selector + ] + ) + + import_apkg.upload( + fn=enhanced_file_upload, + inputs=[import_apkg, input_type], + outputs=[ + flashcard_input, + deck_info, + validation_status, + card_selector + ] + ) + + # Validation handler + validate_button.click( + fn=lambda content, input_format: ( + handle_validation(content, input_format), + generate_card_choices(content) if content else [] + ), + inputs=[flashcard_input, input_type], + outputs=[validation_status, card_selector] + ) + + # Card editing handlers + # Card selector change event + card_selector.change( + fn=load_card_for_editing, + inputs=[card_selector, flashcard_input], + outputs=[ + card_type, + front_content, + back_content, + tags_input, + notes_input, + front_preview, + back_preview + ] + ) + + # Live preview updates + front_content.change( + fn=update_preview, + inputs=[front_content], + outputs=[front_preview] + ) + + back_content.change( + fn=update_preview, + inputs=[back_content], + outputs=[back_preview] + ) + + # Card update handler + update_card_button.click( + fn=update_card_with_validation, + inputs=[ + flashcard_input, + card_selector, + card_type, + front_content, + back_content, + tags_input, + notes_input + ], + outputs=[ + flashcard_input, + validation_status, + card_selector + ] + ) + + # Delete card handler + delete_card_button.click( + fn=delete_card, + inputs=[card_selector, flashcard_input], + outputs=[flashcard_input, validation_status, card_selector] + ) + + # Export handler + export_button.click( + fn=export_cards, + inputs=[flashcard_input, export_format], + outputs=[export_status, export_file] + ) + + return ( + flashcard_input, + import_json, + import_apkg, + validate_button, + validation_status, + card_selector, + card_type, + front_content, + back_content, + front_preview, + back_preview, + tags_input, + notes_input, + update_card_button, + delete_card_button, + export_format, + export_button, + export_file, + export_status, + deck_info + ) + + +def create_anki_generator_tab(): + with gr.TabItem("Anki Deck Generator", visible=True): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + custom_css = """ + .chatbot-container .message-wrap .message { + font-size: 14px !important; + } + """ + with gr.TabItem("LLM Chat & Anki Deck Creation", visible=True): + gr.Markdown("# Chat with an LLM to help you come up with Questions/Answers for an Anki Deck") + chat_history = gr.State([]) + media_content = gr.State({}) + selected_parts = gr.State([]) + conversation_id = gr.State(None) + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + + with gr.Row(): + with gr.Column(scale=1): + search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." 
+ ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." + ) + search_button = gr.Button("Search") + items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) + item_mapping = gr.State({}) + with gr.Row(): + use_content = gr.Checkbox(label="Use Content") + use_summary = gr.Checkbox(label="Use Summary") + use_prompt = gr.Checkbox(label="Use Prompt") + save_conversation = gr.Checkbox(label="Save Conversation", value=False, visible=True) + with gr.Row(): + temperature = gr.Slider(label="Temperature", minimum=0.00, maximum=1.0, step=0.05, value=0.7) + with gr.Row(): + conversation_search = gr.Textbox(label="Search Conversations") + with gr.Row(): + search_conversations_btn = gr.Button("Search Conversations") + with gr.Row(): + previous_conversations = gr.Dropdown(label="Select Conversation", choices=[], interactive=True) + with gr.Row(): + load_conversations_btn = gr.Button("Load Selected Conversation") + + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" + ) + api_key = gr.Textbox(label="API Key (if required)", type="password") + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", + value=False, + visible=True) + preset_prompt_checkbox = gr.Checkbox(label="Use a Pre-set Prompt", + value=False, + visible=True) + with gr.Row(visible=False) as preset_prompt_controls: + prev_prompt_page = gr.Button("Previous") + next_prompt_page = gr.Button("Next") + current_prompt_page_text = gr.Text(f"Page {current_page} of {total_pages}") + current_prompt_page_state = gr.State(value=1) + + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=initial_prompts + ) + user_prompt = gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) + system_prompt_input = gr.Textbox(label="System Prompt", + value="You are a helpful AI assistant", + lines=3, + visible=False) + with gr.Column(scale=2): + chatbot = gr.Chatbot(height=800, elem_classes="chatbot-container") + msg = gr.Textbox(label="Enter your message") + submit = gr.Button("Submit") + regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) + clear_chat_button = gr.Button("Clear Chat") + + chat_media_name = gr.Textbox(label="Custom Chat Name (optional)") + save_chat_history_to_db = gr.Button("Save Chat History to Database") + save_status = gr.Textbox(label="Save Status", interactive=False) + save_chat_history_as_file = gr.Button("Save Chat History as File") + download_file = gr.File(label="Download Chat History") + + search_button.click( + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], + outputs=[items_output, item_mapping] + ) + + def update_prompt_visibility(custom_prompt_checked, preset_prompt_checked): + user_prompt_visible = custom_prompt_checked + system_prompt_visible = custom_prompt_checked + preset_prompt_visible = preset_prompt_checked + preset_prompt_controls_visible = preset_prompt_checked + return ( + gr.update(visible=user_prompt_visible, interactive=user_prompt_visible), + gr.update(visible=system_prompt_visible, interactive=system_prompt_visible), +
gr.update(visible=preset_prompt_visible, interactive=preset_prompt_visible), + gr.update(visible=preset_prompt_controls_visible) + ) + + def update_prompt_page(direction, current_page_val): + new_page = current_page_val + direction + if new_page < 1: + new_page = 1 + prompts, total_pages, _ = list_prompts(page=new_page, per_page=20) + if new_page > total_pages: + new_page = total_pages + prompts, total_pages, _ = list_prompts(page=new_page, per_page=20) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + + def clear_chat(): + return [], None # Return empty list for chatbot and None for conversation_id + + custom_prompt_checkbox.change( + update_prompt_visibility, + inputs=[custom_prompt_checkbox, preset_prompt_checkbox], + outputs=[user_prompt, system_prompt_input, preset_prompt, preset_prompt_controls] + ) + + preset_prompt_checkbox.change( + update_prompt_visibility, + inputs=[custom_prompt_checkbox, preset_prompt_checkbox], + outputs=[user_prompt, system_prompt_input, preset_prompt, preset_prompt_controls] + ) + + prev_prompt_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + next_prompt_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + submit.click( + chat_wrapper, + inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, + conversation_id, + save_conversation, temperature, system_prompt_input], + outputs=[msg, chatbot, conversation_id] + ).then( # Clear the message box after submission + lambda x: gr.update(value=""), + inputs=[chatbot], + outputs=[msg] + ).then( # Clear the user prompt after the first message + lambda: (gr.update(value=""), gr.update(value="")), + outputs=[user_prompt, system_prompt_input] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] + ) + + + clear_chat_button.click( + clear_chat, + outputs=[chatbot, conversation_id] + ) + + items_output.change( + update_chat_content, + inputs=[items_output, use_content, use_summary, use_prompt, item_mapping], + outputs=[media_content, selected_parts] + ) + + use_content.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + use_summary.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + use_prompt.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + items_output.change(debug_output, inputs=[media_content, selected_parts], outputs=[]) + + search_conversations_btn.click( + search_conversations, + inputs=[conversation_search], + outputs=[previous_conversations] + ) + + load_conversations_btn.click( + clear_chat, + outputs=[chatbot, chat_history] + ).then( + load_conversation, + inputs=[previous_conversations], + outputs=[chatbot, conversation_id] + ) + + previous_conversations.change( + load_conversation, + inputs=[previous_conversations], + outputs=[chat_history] + ) + + save_chat_history_as_file.click( + save_chat_history, + inputs=[chatbot, conversation_id], + outputs=[download_file] + ) + + save_chat_history_to_db.click( + save_chat_history_to_db_wrapper, + inputs=[chatbot, conversation_id, media_content, chat_media_name], + outputs=[conversation_id, 
gr.Textbox(label="Save Status")] + ) + + regenerate_button.click( + regenerate_last_message, + inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temperature, + system_prompt_input], + outputs=[chatbot, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] + ) + gr.Markdown("# Create Anki Deck") + + with gr.Row(): + # Left Column: Deck Settings + with gr.Column(scale=1): + gr.Markdown("## Deck Settings") + deck_name = gr.Textbox( + label="Deck Name", + placeholder="My Study Deck", + value="My Study Deck" + ) + + deck_description = gr.Textbox( + label="Deck Description", + placeholder="Description of your deck", + lines=2 + ) + + note_type = gr.Radio( + choices=["Basic", "Basic (and reversed)", "Cloze"], + label="Note Type", + value="Basic" + ) + + # Card Fields based on note type + with gr.Group() as basic_fields: + front_template = gr.Textbox( + label="Front Template (HTML)", + value="{{Front}}", + lines=3 + ) + back_template = gr.Textbox( + label="Back Template (HTML)", + value="{{FrontSide}}
{{Back}}", + lines=3 + ) + + with gr.Group() as cloze_fields: + cloze_template = gr.Textbox( + label="Cloze Template (HTML)", + value="{{cloze:Text}}", + lines=3, + visible=False + ) + + css_styling = gr.Textbox( + label="Card Styling (CSS)", + value=".card {\n font-family: arial;\n font-size: 20px;\n text-align: center;\n color: black;\n background-color: white;\n}\n\n.cloze {\n font-weight: bold;\n color: blue;\n}", + lines=5 + ) + + # Right Column: Card Creation + with gr.Column(scale=1): + gr.Markdown("## Add Cards") + + with gr.Group() as basic_input: + front_content = gr.TextArea( + label="Front Content", + placeholder="Question or prompt", + lines=3 + ) + back_content = gr.TextArea( + label="Back Content", + placeholder="Answer", + lines=3 + ) + + with gr.Group() as cloze_input: + cloze_content = gr.TextArea( + label="Cloze Content", + placeholder="Text with {{c1::cloze}} deletions", + lines=3, + visible=False + ) + + tags_input = gr.TextArea( + label="Tags (comma-separated)", + placeholder="tag1, tag2, tag3", + lines=1 + ) + + add_card_btn = gr.Button("Add Card") + + cards_list = gr.JSON( + label="Cards in Deck", + value={"cards": []} + ) + + clear_cards_btn = gr.Button("Clear All Cards", variant="stop") + + with gr.Row(): + generate_deck_btn = gr.Button("Generate Deck", variant="primary") + download_deck = gr.File(label="Download Deck") + generation_status = gr.Markdown("") + + def update_note_type_fields(note_type: str): + if note_type == "Cloze": + return { + basic_input: gr.update(visible=False), + cloze_input: gr.update(visible=True), + basic_fields: gr.update(visible=False), + cloze_fields: gr.update(visible=True) + } + else: + return { + basic_input: gr.update(visible=True), + cloze_input: gr.update(visible=False), + basic_fields: gr.update(visible=True), + cloze_fields: gr.update(visible=False) + } + + def add_card(note_type: str, front: str, back: str, cloze: str, tags: str, current_cards: Dict[str, List]): + if not current_cards: + current_cards = {"cards": []} + + cards_data = current_cards["cards"] + + # Process tags + card_tags = [tag.strip() for tag in tags.split(',') if tag.strip()] + + new_card = { + "id": f"CARD_{len(cards_data) + 1}", + "tags": card_tags + } + + if note_type == "Cloze": + if not cloze or "{{c" not in cloze: + return current_cards, "❌ Invalid cloze format. Use {{c1::text}} syntax." + new_card.update({ + "type": "cloze", + "content": cloze + }) + else: + if not front or not back: + return current_cards, "❌ Both front and back content are required." + new_card.update({ + "type": "basic", + "front": front, + "back": back, + "is_reverse": note_type == "Basic (and reversed)" + }) + + cards_data.append(new_card) + return {"cards": cards_data}, "✅ Card added successfully!" + + def clear_cards() -> Tuple[Dict[str, List], str]: + return {"cards": []}, "✅ All cards cleared!" + + def generate_anki_deck( + deck_name: str, + deck_description: str, + note_type: str, + front_template: str, + back_template: str, + cloze_template: str, + css: str, + cards_data: Dict[str, List] + ) -> Tuple[Optional[str], str]: + try: + if not cards_data or not cards_data.get("cards"): + return None, "❌ No cards to generate deck from!" + + # Create model based on note type + if note_type == "Cloze": + model = genanki.Model( + 1483883320, # Random model ID + 'Cloze Model', + fields=[ + {'name': 'Text'}, + {'name': 'Back Extra'} + ], + templates=[{ + 'name': 'Cloze Card', + 'qfmt': cloze_template, + 'afmt': cloze_template + '

{{Back Extra}}' + }], + css=css, + # FIXME CLOZE DOESNT EXIST + model_type=1 + ) + else: + templates = [{ + 'name': 'Card 1', + 'qfmt': front_template, + 'afmt': back_template + }] + + if note_type == "Basic (and reversed)": + templates.append({ + 'name': 'Card 2', + 'qfmt': '{{Back}}', + 'afmt': '{{FrontSide}}
{{Front}}' + }) + + model = genanki.Model( + 1607392319, # Random model ID + 'Basic Model', + fields=[ + {'name': 'Front'}, + {'name': 'Back'} + ], + templates=templates, + css=css + ) + + # Create deck + deck = genanki.Deck( + 2059400110, # Random deck ID + deck_name, + description=deck_description + ) + + # Add cards to deck + for card in cards_data["cards"]: + if card["type"] == "cloze": + note = genanki.Note( + model=model, + fields=[card["content"], ""], + tags=card["tags"] + ) + else: + note = genanki.Note( + model=model, + fields=[card["front"], card["back"]], + tags=card["tags"] + ) + deck.add_note(note) + + # Save deck to temporary file + temp_dir = tempfile.mkdtemp() + deck_path = os.path.join(temp_dir, f"{deck_name}.apkg") + genanki.Package(deck).write_to_file(deck_path) + + return deck_path, "✅ Deck generated successfully!" + + except Exception as e: + return None, f"❌ Error generating deck: {str(e)}" + + # Register event handlers + note_type.change( + fn=update_note_type_fields, + inputs=[note_type], + outputs=[basic_input, cloze_input, basic_fields, cloze_fields] + ) + + add_card_btn.click( + fn=add_card, + inputs=[ + note_type, + front_content, + back_content, + cloze_content, + tags_input, + cards_list + ], + outputs=[cards_list, generation_status] + ) + + clear_cards_btn.click( + fn=clear_cards, + inputs=[], + outputs=[cards_list, generation_status] + ) + + generate_deck_btn.click( + fn=generate_anki_deck, + inputs=[ + deck_name, + deck_description, + note_type, + front_template, + back_template, + cloze_template, + css_styling, + cards_list + ], + outputs=[download_deck, generation_status] + ) + + + return ( + deck_name, + deck_description, + note_type, + front_template, + back_template, + cloze_template, + css_styling, + front_content, + back_content, + cloze_content, + tags_input, + cards_list, + add_card_btn, + clear_cards_btn, + generate_deck_btn, + download_deck, + generation_status + ) + +# +# End of Anki_Validation_tab.py +############################################################################################################ diff --git a/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py b/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py index a9317ba17..67cf2a22b 100644 --- a/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py @@ -9,7 +9,7 @@ # # Local Imports from App_Function_Libraries.Audio.Audio_Files import process_audio_files -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models from App_Function_Libraries.Utils.Utils import cleanup_temp_files, default_api_endpoint, global_api_endpoints, \ @@ -60,54 +60,133 @@ def create_audio_processing_tab(): keep_timestamps_input = gr.Checkbox(label="Keep Timestamps", value=True) with gr.Row(): - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = 
gr.State(value=1) + with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + with gr.Row(): - system_prompt_input = gr.Textbox(label="System Prompt", - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] -**Bulleted Note Creation Guidelines** - -**Headings**: -- Based on referenced topics, not categories like quotes or terms -- Surrounded by **bold** formatting -- Not listed as bullet points -- No space between headings and list items underneath - -**Emphasis**: -- **Important terms** set in bold font -- **Text ending in a colon**: also bolded - -**Review**: -- Ensure adherence to specified format -- Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] -""", - lines=3, - visible=False) + custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False + ) + with gr.Row(): + system_prompt_input = gr.Textbox( + label="System Prompt", + value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhere to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] + """, + lines=3, + visible=False + ) custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[custom_prompt_checkbox], outputs=[custom_prompt_input, system_prompt_input] ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) + # Update prompts when a preset is selected def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -117,7 +196,7 @@ def update_prompts(preset_name): preset_prompt.change( update_prompts, - inputs=preset_prompt, + inputs=[preset_prompt], outputs=[custom_prompt_input, system_prompt_input] ) # Refactored API selection dropdown diff --git a/App_Function_Libraries/Gradio_UI/Backup_Functionality.py b/App_Function_Libraries/Gradio_UI/Backup_Functionality.py index 672975da2..b3b059e4f 100644 --- 
a/App_Function_Libraries/Gradio_UI/Backup_Functionality.py +++ b/App_Function_Libraries/Gradio_UI/Backup_Functionality.py @@ -14,7 +14,7 @@ # # Functions: -def create_backup(): +def create_db_backup(): backup_file = create_automated_backup(db_path, backup_dir) return f"Backup created: {backup_file}" @@ -42,18 +42,7 @@ def create_backup_tab(): create_button = gr.Button("Create Backup") create_output = gr.Textbox(label="Result") with gr.Column(): - create_button.click(create_backup, inputs=[], outputs=create_output) - - -def create_view_backups_tab(): - with gr.TabItem("View Backups", visible=True): - gr.Markdown("# Browse available backups") - with gr.Row(): - with gr.Column(): - view_button = gr.Button("View Backups") - with gr.Column(): - backup_list = gr.Textbox(label="Available Backups") - view_button.click(list_backups, inputs=[], outputs=backup_list) + create_button.click(create_db_backup, inputs=[], outputs=create_output) def create_restore_backup_tab(): diff --git a/App_Function_Libraries/Gradio_UI/Backup_RAG_Notes_Character_Chat_tab.py b/App_Function_Libraries/Gradio_UI/Backup_RAG_Notes_Character_Chat_tab.py new file mode 100644 index 000000000..805d360ab --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Backup_RAG_Notes_Character_Chat_tab.py @@ -0,0 +1,195 @@ +# Backup_RAG_Notes_Character_Chat_tab.py +# Functionality for managing database backups +# +# Imports: +import os +import shutil +import gradio as gr +from typing import Dict, List +# +# Local Imports: +from App_Function_Libraries.DB.DB_Manager import create_automated_backup +from App_Function_Libraries.DB.DB_Backups import create_backup, create_incremental_backup, restore_single_db_backup + + +# +# End of Imports +####################################################################################################################### +# +# Functions: + +def get_db_specific_backups(backup_dir: str, db_name: str) -> List[str]: + """Get list of backups specific to a database.""" + all_backups = [f for f in os.listdir(backup_dir) if f.endswith(('.db', '.sqlib'))] + db_specific_backups = [ + backup for backup in all_backups + if backup.startswith(f"{db_name}_") + ] + return sorted(db_specific_backups, reverse=True) # Most recent first + +def create_backup_tab(db_path: str, backup_dir: str, db_name: str): + """Create the backup creation tab for a database.""" + gr.Markdown("## Create Database Backup") + gr.Markdown(f"This will create a backup in the directory: `{backup_dir}`") + with gr.Row(): + with gr.Column(): + #automated_backup_btn = gr.Button("Create Simple Backup") + full_backup_btn = gr.Button("Create Full Backup") + incr_backup_btn = gr.Button("Create Incremental Backup") + with gr.Column(): + backup_output = gr.Textbox(label="Result") + + def create_db_backup(): + backup_file = create_automated_backup(db_path, backup_dir) + return f"Backup created: {backup_file}" + + # automated_backup_btn.click( + # fn=create_db_backup, + # inputs=[], + # outputs=[backup_output] + # ) + full_backup_btn.click( + fn=lambda: create_backup(db_path, backup_dir, db_name), + inputs=[], + outputs=[backup_output] + ) + incr_backup_btn.click( + fn=lambda: create_incremental_backup(db_path, backup_dir, db_name), + inputs=[], + outputs=[backup_output] + ) + +def create_view_backups_tab(backup_dir: str, db_name: str): + """Create the backup viewing tab for a database.""" + gr.Markdown("## Available Backups") + with gr.Row(): + with gr.Column(): + view_btn = gr.Button("Refresh Backup List") + with gr.Column(): + backup_list = gr.Textbox(label="Available Backups")
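A minimal illustrative sketch of the prefix filter above, assuming the naming convention checked by validate_backup_name() further down (the directory path and timestamps here are hypothetical, invented for illustration):

    # Backup filenames are assumed to look like f"{db_name}_backup_..." for full
    # backups and f"{db_name}_incremental_..." for incrementals, ending in .db or .sqlib.
    #     get_db_specific_backups("/backups", "media")
    #     # -> ["media_backup_20240102_090000.db",
    #     #     "media_incremental_20240101_120000.sqlib"]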
+ + def list_db_backups(): + """List backups specific to this database.""" + backups = get_db_specific_backups(backup_dir, db_name) + return "\n".join(backups) if backups else f"No backups found for {db_name} database" + + view_btn.click( + fn=list_db_backups, + inputs=[], + outputs=[backup_list] + ) + +def validate_backup_name(backup_name: str, db_name: str) -> bool: + """Validate that the backup name matches the database being restored.""" + # Check if backup name starts with the database name prefix and has valid extension + valid_prefixes = [ + f"{db_name}_backup_", # Full backup prefix + f"{db_name}_incremental_" # Incremental backup prefix + ] + has_valid_prefix = any(backup_name.startswith(prefix) for prefix in valid_prefixes) + has_valid_extension = backup_name.endswith(('.db', '.sqlib')) + return has_valid_prefix and has_valid_extension + +def create_restore_backup_tab(db_path: str, backup_dir: str, db_name: str): + """Create the backup restoration tab for a database.""" + gr.Markdown("## Restore Database") + gr.Markdown("⚠️ **Warning**: Restoring a backup will overwrite the current database.") + with gr.Row(): + with gr.Column(): + backup_input = gr.Textbox(label="Backup Filename") + restore_btn = gr.Button("Restore", variant="primary") + with gr.Column(): + restore_output = gr.Textbox(label="Result") + + def secure_restore(backup_name: str) -> str: + """Restore backup with validation checks.""" + if not backup_name: + return "Please enter a backup filename" + + # Validate backup name format + if not validate_backup_name(backup_name, db_name): + return f"Invalid backup file. Please select a backup file that starts with '{db_name}_backup_' or '{db_name}_incremental_'" + + # Check if backup exists + backup_path = os.path.join(backup_dir, backup_name) + if not os.path.exists(backup_path): + return f"Backup file not found: {backup_name}" + + # Proceed with restore + return restore_single_db_backup(db_path, backup_dir, db_name, backup_name) + + restore_btn.click( + fn=secure_restore, + inputs=[backup_input], + outputs=[restore_output] + ) + +def create_media_db_tabs(db_config: Dict[str, str]): + """Create all tabs for the Media database.""" + create_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='media' + ) + create_view_backups_tab( + backup_dir=db_config['backup_dir'], + db_name='media' + ) + create_restore_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='media' + ) + +def create_rag_chat_tabs(db_config: Dict[str, str]): + """Create all tabs for the RAG Chat database.""" + create_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='rag_qa' # Updated to match DB_Manager.py + ) + create_view_backups_tab( + backup_dir=db_config['backup_dir'], + db_name='rag_qa' # Updated to match DB_Manager.py + ) + create_restore_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='rag_qa' # Updated to match DB_Manager.py + ) + +def create_character_chat_tabs(db_config: Dict[str, str]): + """Create all tabs for the Character Chat database.""" + create_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='chatDB' # Updated to match DB_Manager.py + ) + create_view_backups_tab( + backup_dir=db_config['backup_dir'], + db_name='chatDB' # Updated to match DB_Manager.py + ) + create_restore_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='chatDB' + ) + +def 
create_database_management_interface( + media_db_config: Dict[str, str], + rag_db_config: Dict[str, str], + char_db_config: Dict[str, str] +): + """Create the main database management interface with tabs for each database.""" + with gr.TabItem("Media Database", id="media_db_group", visible=True): + create_media_db_tabs(media_db_config) + + with gr.TabItem("RAG Chat Database", id="rag_chat_group", visible=True): + create_rag_chat_tabs(rag_db_config) + + with gr.TabItem("Character Chat Database", id="character_chat_group", visible=True): + create_character_chat_tabs(char_db_config) + +# +# End of Functions +####################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py b/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py index 86a2b0488..aa905dea7 100644 --- a/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py @@ -8,24 +8,19 @@ # #################### # Imports +import logging # # External Imports -import logging - import gradio as gr # # Local Imports -from App_Function_Libraries.Books.Book_Ingestion_Lib import process_zip_file, import_epub, import_file_handler +from App_Function_Libraries.Books.Book_Ingestion_Lib import import_file_handler from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name - - # ######################################################################################################################## # # Functions: - - def create_import_book_tab(): try: default_value = None @@ -37,42 +32,60 @@ def create_import_book_tab(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + with gr.TabItem("Ebook(epub) Files", visible=True): with gr.Row(): with gr.Column(): gr.Markdown("# Import .epub files") - gr.Markdown("Upload a single .epub file or a .zip file containing multiple .epub files") + gr.Markdown("Upload multiple .epub files or a .zip file containing multiple .epub files") gr.Markdown( "🔗 **How to remove DRM from your ebooks:** [Reddit Guide](https://www.reddit.com/r/Calibre/comments/1ck4w8e/2024_guide_on_removing_drm_from_kobo_kindle_ebooks/)") - import_file = gr.File(label="Upload file for import", - file_types=[".epub", ".zip", ".html", ".htm", ".xml", ".opml"]) - title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content (for single files)") - author_input = gr.Textbox(label="Author", placeholder="Enter the author's name (for single files)") - keywords_input = gr.Textbox(label="Keywords (like genre or publish year)", - placeholder="Enter keywords, comma-separated") - system_prompt_input = gr.Textbox(label="System Prompt", lines=3, - value="""" - You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] - **Bulleted Note Creation Guidelines** - **Headings**: - - Based on referenced topics, not categories like quotes or terms - - Surrounded by **bold** formatting - - Not listed as bullet points - - No space between headings and list items underneath - - **Emphasis**: - - **Important terms** set in bold font - - **Text ending in a colon**: also bolded + # Updated to support multiple files + import_files = gr.File( + label="Upload files for import", + file_count="multiple", + file_types=[".epub", ".zip", ".html", ".htm", ".xml", ".opml"] + ) - **Review**: - - Ensure adherence to specified format - - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] - """, ) - custom_prompt_input = gr.Textbox(label="Custom User Prompt", - placeholder="Enter a custom user prompt for summarization (optional)") + # Optional fields for overriding auto-extracted metadata + author_input = gr.Textbox( + label="Author Override (optional)", + placeholder="Enter author name to override auto-extracted metadata" + ) + keywords_input = gr.Textbox( + label="Keywords (like genre or publish year)", + placeholder="Enter keywords, comma-separated - will be applied to all uploaded books" + ) + system_prompt_input = gr.Textbox( + label="System Prompt", + lines=3, + value=""" + You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhere to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] + """ + ) + custom_prompt_input = gr.Textbox( + label="Custom User Prompt", + placeholder="Enter a custom user prompt for summarization (optional)" + ) auto_summarize_checkbox = gr.Checkbox(label="Auto-summarize", value=False) - # Refactored API selection dropdown + + # API configuration api_name_input = gr.Dropdown( choices=["None"] + [format_api_name(api) for api in global_api_endpoints], value=default_value, @@ -81,13 +94,27 @@ def create_import_book_tab(): api_key_input = gr.Textbox(label="API Key", type="password") # Chunking options - max_chunk_size = gr.Slider(minimum=100, maximum=2000, value=500, step=50, label="Max Chunk Size") - chunk_overlap = gr.Slider(minimum=0, maximum=500, value=200, step=10, label="Chunk Overlap") - custom_chapter_pattern = gr.Textbox(label="Custom Chapter Pattern (optional)", - placeholder="Enter a custom regex pattern for chapter detection") + max_chunk_size = gr.Slider( + minimum=100, + maximum=2000, + value=500, + step=50, + label="Max Chunk Size" + ) + chunk_overlap = gr.Slider( + minimum=0, + maximum=500, + value=200, + step=10, + label="Chunk Overlap" + ) + custom_chapter_pattern = gr.Textbox( + label="Custom Chapter Pattern (optional)", + placeholder="Enter a custom regex pattern for chapter detection" + ) + import_button = gr.Button("Import eBooks") - import_button = gr.Button("Import eBook(s)") with gr.Column(): with gr.Row(): import_output = gr.Textbox(label="Import Status", lines=10, interactive=False) @@ -95,10 +122,10 @@ def create_import_book_tab(): import_button.click( fn=import_file_handler, inputs=[ - import_file, - title_input, + import_files, # Now handles multiple files author_input, keywords_input, + system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, @@ -110,8 +137,8 @@ def create_import_book_tab(): outputs=import_output ) - return import_file, title_input, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output + return import_files, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output # # End of File -######################################################################################################################## \ No newline at end of file +######################################################################################################################## diff --git a/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py b/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py index e2ad29ecf..56ee24113 100644 --- a/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py @@ -21,7 +21,7 @@ from App_Function_Libraries.Character_Chat.Character_Chat_Lib import validate_character_book, validate_v2_card, \ replace_placeholders, 
replace_user_placeholder, extract_json_from_image, parse_character_book, \ load_chat_and_character, load_chat_history, load_character_and_image, extract_character_id, load_character_wrapper -from App_Function_Libraries.Chat import chat +from App_Function_Libraries.Chat.Chat_Functions import chat, approximate_token_count from App_Function_Libraries.DB.Character_Chat_DB import ( add_character_card, get_character_cards, @@ -32,10 +32,10 @@ update_character_chat, delete_character_chat, delete_character_card, - update_character_card, search_character_chats, + update_character_card, search_character_chats, save_chat_history_to_character_db, ) from App_Function_Libraries.Utils.Utils import sanitize_user_input, format_api_name, global_api_endpoints, \ - default_api_endpoint + default_api_endpoint, load_comprehensive_config # @@ -267,6 +267,25 @@ def create_character_card_interaction_tab(): default_value = None with gr.TabItem("Chat with a Character Card", visible=True): gr.Markdown("# Chat with a Character Card") + with gr.Row(): + with gr.Column(scale=1): + # Checkbox to Decide Whether to Save Chats by Default + config = load_comprehensive_config() + auto_save_value = config.get('auto-save', 'save_character_chats', fallback='False') + auto_save_checkbox = gr.Checkbox(label="Save chats automatically", value=auto_save_value) + chat_media_name = gr.Textbox(label="Custom Chat Name (optional)", visible=True) + save_chat_history_to_db = gr.Button("Save Chat History to Database") + save_status = gr.Textbox(label="Status", interactive=False) + with gr.Column(scale=2): + gr.Markdown("## Search and Load Existing Chats") + chat_search_query = gr.Textbox( + label="Search Chats", + placeholder="Enter chat name or keywords to search" + ) + chat_search_button = gr.Button("Search Chats") + chat_search_dropdown = gr.Dropdown(label="Search Results", choices=[], visible=False) + load_chat_button = gr.Button("Load Selected Chat", visible=False) + with gr.Row(): with gr.Column(scale=1): character_image = gr.Image(label="Character Image", type="pil") @@ -291,24 +310,8 @@ def create_character_card_interaction_tab(): temperature_slider = gr.Slider( minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Temperature" ) - import_chat_button = gr.Button("Import Chat History") chat_file_upload = gr.File(label="Upload Chat History JSON", visible=True) - - # Chat History Import and Search - gr.Markdown("## Search and Load Existing Chats") - chat_search_query = gr.Textbox( - label="Search Chats", - placeholder="Enter chat name or keywords to search" - ) - chat_search_button = gr.Button("Search Chats") - chat_search_dropdown = gr.Dropdown(label="Search Results", choices=[], visible=False) - load_chat_button = gr.Button("Load Selected Chat", visible=False) - - # Checkbox to Decide Whether to Save Chats by Default - auto_save_checkbox = gr.Checkbox(label="Save chats automatically", value=True) - chat_media_name = gr.Textbox(label="Custom Chat Name (optional)", visible=True) - save_chat_history_to_db = gr.Button("Save Chat History to Database") - save_status = gr.Textbox(label="Status", interactive=False) + import_chat_button = gr.Button("Import Chat History") with gr.Column(scale=2): chat_history = gr.Chatbot(label="Conversation", height=800) @@ -317,6 +320,7 @@ def create_character_card_interaction_tab(): answer_for_me_button = gr.Button("Answer for Me") continue_talking_button = gr.Button("Continue Talking") regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token 
Count", value=0, interactive=False) clear_chat_button = gr.Button("Clear Chat") save_snapshot_button = gr.Button("Save Chat Snapshot") update_chat_dropdown = gr.Dropdown(label="Select Chat to Update", choices=[], visible=False) @@ -501,23 +505,114 @@ def character_chat_wrapper( return history, save_status + def validate_chat_history(chat_history: List[Tuple[Optional[str], str]]) -> bool: + """ + Validate the chat history format and content. + + Args: + chat_history: List of message tuples (user_message, bot_message) + + Returns: + bool: True if valid, False if invalid + """ + if not isinstance(chat_history, list): + return False + + for entry in chat_history: + if not isinstance(entry, tuple) or len(entry) != 2: + return False + # First element can be None (for system messages) or str + if not (entry[0] is None or isinstance(entry[0], str)): + return False + # Second element (bot response) must be str and not empty + if not isinstance(entry[1], str) or not entry[1].strip(): + return False + + return True + + def sanitize_conversation_name(name: str) -> str: + """ + Sanitize the conversation name. + + Args: + name: Raw conversation name + + Returns: + str: Sanitized conversation name + """ + # Remove any non-alphanumeric characters except spaces and basic punctuation + sanitized = re.sub(r'[^a-zA-Z0-9\s\-_.]', '', name) + # Limit length + sanitized = sanitized[:100] + # Ensure it's not empty + if not sanitized.strip(): + sanitized = f"Chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + return sanitized + def save_chat_history_to_db_wrapper( - chat_history, conversation_id, media_content, - chat_media_name, char_data, auto_save - ): - if not char_data or not chat_history: - return "No character or chat history available.", "" + chat_history: List[Tuple[Optional[str], str]], + conversation_id: str, + media_content: Dict, + chat_media_name: str, + char_data: Dict, + auto_save: bool + ) -> Tuple[str, str]: + """ + Save chat history to the database with validation. - character_id = char_data.get('id') - if not character_id: - return "Character ID not found.", "" + Args: + chat_history: List of message tuples + conversation_id: Current conversation ID + media_content: Media content metadata + chat_media_name: Custom name for the chat + char_data: Character data dictionary + auto_save: Auto-save flag - conversation_name = chat_media_name or f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - chat_id = add_character_chat(character_id, conversation_name, chat_history) - if chat_id: - return f"Chat saved successfully with ID {chat_id}.", "" - else: - return "Failed to save chat.", "" + Returns: + Tuple[str, str]: (status message, detail message) + """ + try: + # Basic input validation + if not chat_history: + return "No chat history to save.", "" + + if not validate_chat_history(chat_history): + return "Invalid chat history format.", "Please ensure the chat history is valid." + + if not char_data: + return "No character selected.", "Please select a character first." 
+ + character_id = char_data.get('id') + if not character_id: + return "Invalid character data: No character ID found.", "" + + # Sanitize and prepare conversation name + conversation_name = sanitize_conversation_name( + chat_media_name if chat_media_name.strip() + else f"Chat with {char_data.get('name', 'Unknown')} - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + ) + + # Save to the database using your existing function + chat_id = save_chat_history_to_character_db( + character_id=character_id, + conversation_name=conversation_name, + chat_history=chat_history + ) + + if chat_id: + success_message = ( + f"Chat saved successfully!\n" + f"ID: {chat_id}\n" + f"Name: {conversation_name}\n" + f"Messages: {len(chat_history)}" + ) + return success_message, "" + else: + return "Failed to save chat to database.", "Database operation failed." + + except Exception as e: + logging.error(f"Error saving chat history: {str(e)}", exc_info=True) + return f"Error saving chat: {str(e)}", "Please check the logs for more details." def update_character_info(name): return load_character_and_image(name, user_name.value) @@ -881,6 +976,10 @@ def answer_for_me( auto_save_checkbox ], outputs=[chat_history, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) continue_talking_button.click( @@ -895,6 +994,10 @@ def answer_for_me( auto_save_checkbox ], outputs=[chat_history, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) import_card_button.click( @@ -913,6 +1016,10 @@ def answer_for_me( fn=clear_chat_history, inputs=[character_data, user_name_input], outputs=[chat_history, character_data] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) character_dropdown.change( @@ -938,7 +1045,13 @@ def answer_for_me( auto_save_checkbox ], outputs=[chat_history, save_status] - ).then(lambda: "", outputs=user_input) + ).then( + lambda: "", outputs=user_input + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] + ) regenerate_button.click( fn=regenerate_last_message, @@ -952,6 +1065,10 @@ def answer_for_me( auto_save_checkbox ], outputs=[chat_history, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) import_chat_button.click( @@ -961,8 +1078,12 @@ def answer_for_me( chat_file_upload.change( fn=import_chat_history, - inputs=[chat_file_upload, chat_history, character_data], + inputs=[chat_file_upload, chat_history, character_data, user_name_input], outputs=[chat_history, character_data, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) save_chat_history_to_db.click( @@ -1019,6 +1140,10 @@ def answer_for_me( fn=load_selected_chat_from_search, inputs=[chat_search_dropdown, user_name_input], outputs=[character_data, chat_history, character_image, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) # Show Load Chat Button when a chat is selected @@ -1033,8 +1158,8 @@ def answer_for_me( def create_character_chat_mgmt_tab(): - with gr.TabItem("Character and Chat Management", visible=True): - gr.Markdown("# Character and Chat Management") + with gr.TabItem("Character Chat Management", 
visible=True): + gr.Markdown("# Character Chat Management") with gr.Row(): # Left Column: Character Import and Chat Management diff --git a/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py b/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py index 645873274..a9de2297a 100644 --- a/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py +++ b/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py @@ -17,7 +17,7 @@ from PIL import Image # # Local Imports -from App_Function_Libraries.Chat import chat, load_characters, save_chat_history_to_db_wrapper +from App_Function_Libraries.Chat.Chat_Functions import chat, load_characters, save_chat_history_to_db_wrapper from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper from App_Function_Libraries.Gradio_UI.Writing_tab import generate_writing_feedback from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints diff --git a/App_Function_Libraries/Gradio_UI/Chat_ui.py b/App_Function_Libraries/Gradio_UI/Chat_ui.py index e1410a0bb..6df00168d 100644 --- a/App_Function_Libraries/Gradio_UI/Chat_ui.py +++ b/App_Function_Libraries/Gradio_UI/Chat_ui.py @@ -2,24 +2,25 @@ # Description: Chat interface functions for Gradio # # Imports -import html -import json import logging import os import sqlite3 +import time from datetime import datetime # # External Imports import gradio as gr # # Local Imports -from App_Function_Libraries.Chat import chat, save_chat_history, update_chat_content, save_chat_history_to_db_wrapper -from App_Function_Libraries.DB.DB_Manager import add_chat_message, search_chat_conversations, create_chat_conversation, \ - get_chat_messages, update_chat_message, delete_chat_message, load_preset_prompts, db +from App_Function_Libraries.Chat.Chat_Functions import approximate_token_count, chat, save_chat_history, \ + update_chat_content, save_chat_history_to_db_wrapper +from App_Function_Libraries.DB.DB_Manager import db, load_chat_history, start_new_conversation, \ + save_message, search_conversations_by_keywords, \ + get_all_conversations, delete_messages_in_conversation, search_media_db, list_prompts +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown, update_user_prompt +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints - - # # ######################################################################################################################## @@ -92,10 +93,9 @@ def chat_wrapper(message, history, media_content, selected_parts, api_endpoint, # Create a new conversation media_id = media_content.get('id', None) conversation_name = f"Chat about {media_content.get('title', 'Unknown Media')} - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - conversation_id = create_chat_conversation(media_id, conversation_name) - + conversation_id = start_new_conversation(title=conversation_name, media_id=media_id) # Add user message to the database - user_message_id = add_chat_message(conversation_id, "user", message) + user_message_id = save_message(conversation_id, role="user", content=message) # Include the selected parts and custom_prompt only for the first message if not history and selected_parts: @@ -114,7 +114,7 @@ def chat_wrapper(message, history, media_content, selected_parts, api_endpoint, if save_conversation: # Add 
assistant message to the database - add_chat_message(conversation_id, "assistant", bot_message) + save_message(conversation_id, role="assistant", content=bot_message) # Update history new_history = history + [(message, bot_message)] @@ -124,51 +124,57 @@ def chat_wrapper(message, history, media_content, selected_parts, api_endpoint, logging.error(f"Error in chat wrapper: {str(e)}") return "An error occurred.", history, conversation_id + def search_conversations(query): + """Convert existing chat search to use RAG chat functions""" try: - conversations = search_chat_conversations(query) - if not conversations: - print(f"Debug - Search Conversations - No results found for query: {query}") + # Use the RAG search function - search by title if given a query + if query and query.strip(): + results, _, _ = search_conversations_by_keywords( + title_query=query.strip() + ) + else: + # Get all conversations if no query + results, _, _ = get_all_conversations() + + if not results: return gr.update(choices=[]) + # Format choices to match existing UI format conversation_options = [ - (f"{c['conversation_name']} (Media: {c['media_title']}, ID: {c['id']})", c['id']) - for c in conversations + (f"{conv['title']} (ID: {conv['conversation_id'][:8]})", conv['conversation_id']) + for conv in results ] - print(f"Debug - Search Conversations - Options: {conversation_options}") + return gr.update(choices=conversation_options) except Exception as e: - print(f"Debug - Search Conversations - Error: {str(e)}") + logging.error(f"Error searching conversations: {str(e)}") return gr.update(choices=[]) def load_conversation(conversation_id): + """Convert existing load to use RAG chat functions""" if not conversation_id: return [], None - messages = get_chat_messages(conversation_id) - history = [ - (msg['message'], None) if msg['sender'] == 'user' else (None, msg['message']) - for msg in messages - ] - return history, conversation_id - - -def update_message_in_chat(message_id, new_text, history): - update_chat_message(message_id, new_text) - updated_history = [(msg1, msg2) if msg1[1] != message_id and msg2[1] != message_id - else ((new_text, msg1[1]) if msg1[1] == message_id else (new_text, msg2[1])) - for msg1, msg2 in history] - return updated_history + try: + # Use RAG load function + messages, _, _ = load_chat_history(conversation_id) + # Convert to chatbot history format + history = [ + (content, None) if role == 'user' else (None, content) + for role, content in messages + ] -def delete_message_from_chat(message_id, history): - delete_chat_message(message_id) - updated_history = [(msg1, msg2) for msg1, msg2 in history if msg1[1] != message_id and msg2[1] != message_id] - return updated_history + return history, conversation_id + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], None -def regenerate_last_message(history, media_content, selected_parts, api_endpoint, api_key, custom_prompt, temperature, system_prompt): +def regenerate_last_message(history, media_content, selected_parts, api_endpoint, api_key, custom_prompt, temperature, + system_prompt): if not history: return history, "No messages to regenerate." @@ -201,6 +207,45 @@ def regenerate_last_message(history, media_content, selected_parts, api_endpoint return new_history, "Last message regenerated successfully." 
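The `load_conversation` rewrite above reshapes the `(role, content)` rows returned by `load_chat_history` into Gradio's two-column chatbot tuples. As a reviewer aid, here is a self-contained sketch of that conversion; the function name and sample data are illustrative, not part of this diff:

```python
from typing import List, Optional, Tuple

def rows_to_chatbot_history(
        messages: List[Tuple[str, str]]) -> List[Tuple[Optional[str], Optional[str]]]:
    """Map stored (role, content) rows onto Gradio's two-column chat tuples.

    User turns fill the left column, assistant turns the right; the unused
    column is None, mirroring the comprehension in load_conversation().
    """
    return [
        (content, None) if role == 'user' else (None, content)
        for role, content in messages
    ]

# Round-trip example:
rows = [('user', 'Hi'), ('assistant', 'Hello! How can I help?')]
assert rows_to_chatbot_history(rows) == [('Hi', None), (None, 'Hello! How can I help?')]
```

The same two-column shape is what `save_chat_history_to_db_wrapper` walks back out when it persists the user and assistant messages separately.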
+ +def update_dropdown_multiple(query, search_type, keywords=""): + """Updated function to handle multiple search results using search_media_db""" + try: + # Define search fields based on search type + search_fields = [] + if search_type.lower() == "keyword": + # When searching by keyword, we'll search across multiple fields + search_fields = ["title", "content", "author"] + else: + # Otherwise use the specific field + search_fields = [search_type.lower()] + + # Perform the search + results = search_media_db( + search_query=query, + search_fields=search_fields, + keywords=keywords, + page=1, + results_per_page=50 # Adjust as needed + ) + + # Process results + item_map = {} + formatted_results = [] + + for row in results: + id, url, title, type_, content, author, date, prompt, summary = row + # Create a display text that shows relevant info + display_text = f"{title} - {author or 'Unknown'} ({date})" + formatted_results.append(display_text) + item_map[display_text] = id + + return gr.update(choices=formatted_results), item_map + except Exception as e: + logging.error(f"Error in update_dropdown_multiple: {str(e)}") + return gr.update(choices=[]), {} + + def create_chat_interface(): try: default_value = None @@ -226,9 +271,19 @@ def create_chat_interface(): with gr.Row(): with gr.Column(scale=1): - search_query_input = gr.Textbox(label="Search Query", placeholder="Enter your search query here...") - search_type_input = gr.Radio(choices=["Title", "URL", "Keyword", "Content"], value="Title", - label="Search By") + search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." + ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." 
+ ) search_button = gr.Button("Search") items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) item_mapping = gr.State({}) @@ -255,47 +310,53 @@ def create_chat_interface(): label="API for Chat Interaction (Optional)" ) api_key = gr.Textbox(label="API Key (if required)", type="password") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", value=False, visible=True) preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) - user_prompt = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) - system_prompt_input = gr.Textbox(label="System Prompt", - value="You are a helpful AI assitant", - lines=3, - visible=False) + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown(label="Select Preset Prompt", + choices=[], + visible=False) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + system_prompt_input = gr.Textbox(label="System Prompt", + value="You are a helpful AI assistant", + lines=3, + visible=False) + with gr.Row(): + user_prompt = gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) with gr.Column(scale=2): chatbot = gr.Chatbot(height=800, elem_classes="chatbot-container") msg = gr.Textbox(label="Enter your message") submit = gr.Button("Submit") regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) clear_chat_button = gr.Button("Clear Chat") - edit_message_id = gr.Number(label="Message ID to Edit", visible=False) - edit_message_text = gr.Textbox(label="Edit Message", visible=False) - update_message_button = gr.Button("Update Message", visible=False) - - delete_message_id = gr.Number(label="Message ID to Delete", visible=False) - delete_message_button = gr.Button("Delete Message", visible=False) - chat_media_name = gr.Textbox(label="Custom Chat Name(optional)") save_chat_history_to_db = gr.Button("Save Chat History to DataBase") + save_status = gr.Textbox(label="Save Status", interactive=False) save_chat_history_as_file = gr.Button("Save Chat History as File") download_file = gr.File(label="Download Chat History") - save_status = gr.Textbox(label="Save Status", interactive=False) # Restore original functionality search_button.click( - fn=update_dropdown, - inputs=[search_query_input, search_type_input], + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], outputs=[items_output, item_mapping] ) @@ -326,21 +387,72 @@ def clear_chat(): clear_chat, outputs=[chatbot, conversation_id] ) + + # Function to handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # 
page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + preset_prompt.change( update_prompts, - inputs=preset_prompt, + inputs=[preset_prompt], outputs=[user_prompt, system_prompt_input] ) + custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[custom_prompt_checkbox], outputs=[user_prompt, system_prompt_input] ) - preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), - inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] - ) + submit.click( chat_wrapper, inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, conversation_id, @@ -353,6 +465,10 @@ def clear_chat(): ).then( # Clear the user prompt after the first message lambda: (gr.update(value=""), gr.update(value="")), outputs=[user_prompt, system_prompt_input] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) items_output.change( @@ -360,6 +476,7 @@ def clear_chat(): inputs=[items_output, use_content, use_summary, use_prompt, item_mapping], outputs=[media_content, selected_parts] ) + use_content.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], outputs=[selected_parts]) use_summary.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], @@ -389,18 +506,6 @@ def clear_chat(): outputs=[chat_history] ) - update_message_button.click( - update_message_in_chat, - inputs=[edit_message_id, edit_message_text, chat_history], - outputs=[chatbot] - ) - - delete_message_button.click( - delete_message_from_chat, - inputs=[delete_message_id, chat_history], - outputs=[chatbot] - ) - save_chat_history_as_file.click( save_chat_history, inputs=[chatbot, conversation_id], @@ -415,13 +520,15 @@ def clear_chat(): regenerate_button.click( regenerate_last_message, - inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temperature, system_prompt_input], + inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, 
user_prompt, temperature, + system_prompt_input], outputs=[chatbot, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) - chatbot.select(show_edit_message, None, [edit_message_text, edit_message_id, update_message_button]) - chatbot.select(show_delete_message, None, [delete_message_id, delete_message_button]) - def create_chat_interface_stacked(): try: @@ -434,6 +541,7 @@ def create_chat_interface_stacked(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + custom_css = """ .chatbot-container .message-wrap .message { font-size: 14px !important; @@ -448,9 +556,19 @@ def create_chat_interface_stacked(): with gr.Row(): with gr.Column(): - search_query_input = gr.Textbox(label="Search Query", placeholder="Enter your search query here...") - search_type_input = gr.Radio(choices=["Title", "URL", "Keyword", "Content"], value="Title", - label="Search By") + search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." + ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." + ) search_button = gr.Button("Search") items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) item_mapping = gr.State({}) @@ -475,17 +593,45 @@ def create_chat_interface_stacked(): label="API for Chat Interaction (Optional)" ) api_key = gr.Textbox(label="API Key (if required)", type="password") - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=True) - system_prompt = gr.Textbox(label="System Prompt", - value="You are a helpful AI assistant.", - lines=3, - visible=True) - user_prompt = gr.Textbox(label="Custom User Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=True) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) + + with gr.Row(): + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + + system_prompt = gr.Textbox( + label="System Prompt", + value="You are a helpful AI assistant.", + lines=4, + visible=False + ) + user_prompt = gr.Textbox( + label="Custom User Prompt", + placeholder="Enter custom prompt here", + lines=4, + visible=False + ) gr.Markdown("Scroll down for the chat window...") with gr.Row(): with gr.Column(scale=1): @@ -495,20 +641,110 @@ def create_chat_interface_stacked(): with gr.Column(): submit = gr.Button("Submit") regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) clear_chat_button = gr.Button("Clear Chat") chat_media_name = gr.Textbox(label="Custom Chat Name(optional)", visible=True) save_chat_history_to_db = gr.Button("Save Chat History to DataBase") + save_status = gr.Textbox(label="Save Status", 
interactive=False) save_chat_history_as_file = gr.Button("Save Chat History as File") with gr.Column(): download_file = gr.File(label="Download Chat History") # Restore original functionality search_button.click( - fn=update_dropdown, - inputs=[search_query_input, search_type_input], + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], outputs=[items_output, item_mapping] ) + def search_conversations(query): + try: + # Use RAG search with title search + if query and query.strip(): + results, _, _ = search_conversations_by_keywords(title_query=query.strip()) + else: + results, _, _ = get_all_conversations() + + if not results: + return gr.update(choices=[]) + + # Format choices to match UI + conversation_options = [ + (f"{conv['title']} (ID: {conv['conversation_id'][:8]})", conv['conversation_id']) + for conv in results + ] + + return gr.update(choices=conversation_options) + except Exception as e: + logging.error(f"Error searching conversations: {str(e)}") + return gr.update(choices=[]) + + def load_conversation(conversation_id): + if not conversation_id: + return [], None + + try: + # Use RAG load function + messages, _, _ = load_chat_history(conversation_id) + + # Convert to chatbot history format + history = [ + (content, None) if role == 'user' else (None, content) + for role, content in messages + ] + + return history, conversation_id + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], None + + def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, chat_name=None): + log_counter("save_chat_history_to_db_attempt") + start_time = time.time() + logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}") + + try: + # First check if we can access the database + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + except sqlite3.DatabaseError as db_error: + logging.error(f"Database is corrupted or inaccessible: {str(db_error)}") + return conversation_id, gr.update( + value="Database error: The database file appears to be corrupted. Please contact support.") + + # For both new and existing conversations + try: + if not conversation_id: + title = chat_name if chat_name else "Untitled Conversation" + conversation_id = start_new_conversation(title=title) + logging.info(f"Created new conversation with ID: {conversation_id}") + + # Update existing messages + delete_messages_in_conversation(conversation_id) + for user_msg, assistant_msg in chatbot: + if user_msg: + save_message(conversation_id, "user", user_msg) + if assistant_msg: + save_message(conversation_id, "assistant", assistant_msg) + except sqlite3.DatabaseError as db_error: + logging.error(f"Database error during message save: {str(db_error)}") + return conversation_id, gr.update( + value="Database error: Unable to save messages. 
Please try again or contact support.") + + save_duration = time.time() - start_time + log_histogram("save_chat_history_to_db_duration", save_duration) + log_counter("save_chat_history_to_db_success") + + return conversation_id, gr.update(value="Chat history saved successfully!") + + except Exception as e: + log_counter("save_chat_history_to_db_error", labels={"error": str(e)}) + error_message = f"Failed to save chat history: {str(e)}" + logging.error(error_message, exc_info=True) + return conversation_id, gr.update(value=error_message) + def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -516,13 +752,85 @@ def update_prompts(preset_name): gr.update(value=prompts["system_prompt"], visible=True) ) + def clear_chat(): + return [], None, 0 # Empty history, conversation_id, and token count + clear_chat_button.click( clear_chat, - outputs=[chatbot, conversation_id] + outputs=[chatbot, conversation_id, token_count_display] + ) + + # Handle custom prompt checkbox change + def on_custom_prompt_checkbox_change(is_checked): + return ( + gr.update(visible=is_checked), + gr.update(visible=is_checked) + ) + + custom_prompt_checkbox.change( + fn=on_custom_prompt_checkbox_change, + inputs=[custom_prompt_checkbox], + outputs=[user_prompt, system_prompt] + ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected preset_prompt.change( update_prompts, - inputs=preset_prompt, + inputs=[preset_prompt], outputs=[user_prompt, 
system_prompt] ) @@ -531,13 +839,14 @@ def update_prompts(preset_name): inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, conversation_id, save_conversation, temp, system_prompt], outputs=[msg, chatbot, conversation_id] - ).then( # Clear the message box after submission + ).then( lambda x: gr.update(value=""), inputs=[chatbot], outputs=[msg] - ).then( # Clear the user prompt after the first message - lambda: gr.update(value=""), - outputs=[user_prompt, system_prompt] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) items_output.change( @@ -583,17 +892,20 @@ def update_prompts(preset_name): save_chat_history_to_db.click( save_chat_history_to_db_wrapper, inputs=[chatbot, conversation_id, media_content, chat_media_name], - outputs=[conversation_id, gr.Textbox(label="Save Status")] + outputs=[conversation_id, save_status] ) regenerate_button.click( regenerate_last_message, inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temp, system_prompt], outputs=[chatbot, gr.Textbox(label="Regenerate Status")] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) -# FIXME - System prompts def create_chat_interface_multi_api(): try: default_value = None @@ -630,9 +942,31 @@ def create_chat_interface_multi_api(): use_summary = gr.Checkbox(label="Use Summary") use_prompt = gr.Checkbox(label="Use Prompt") with gr.Column(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", choices=load_preset_prompts(), visible=True) - system_prompt = gr.Textbox(label="System Prompt", value="You are a helpful AI assistant.", lines=5) - user_prompt = gr.Textbox(label="Modify Prompt (Prefixed to your message every time)", lines=5, value="", visible=True) + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", + value=False, + visible=True) + preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", + value=False, + visible=True) + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown(label="Select Preset Prompt", + choices=[], + visible=False) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + system_prompt = gr.Textbox(label="System Prompt", + value="You are a helpful AI assistant.", + lines=5, + visible=True) + user_prompt = gr.Textbox(label="Modify Prompt (Prefixed to your message every time)", lines=5, + value="", visible=True) with gr.Row(): chatbots = [] @@ -640,6 +974,7 @@ def create_chat_interface_multi_api(): api_keys = [] temperatures = [] regenerate_buttons = [] + token_count_displays = [] for i in range(3): with gr.Column(): gr.Markdown(f"### Chat Window {i + 1}") @@ -653,6 +988,9 @@ def create_chat_interface_multi_api(): temperature = gr.Slider(label=f"Temperature {i + 1}", minimum=0.0, maximum=1.0, step=0.05, value=0.7) chatbot = gr.Chatbot(height=800, elem_classes="chat-window") + token_count_display = gr.Number(label=f"Approximate Token Count {i + 1}", value=0, + interactive=False) + token_count_displays.append(token_count_display) regenerate_button = gr.Button(f"Regenerate Last Message {i + 1}") chatbots.append(chatbot) api_endpoints.append(api_endpoint) @@ -678,16 +1016,103 @@ def 
create_chat_interface_multi_api(): outputs=[items_output, item_mapping] ) + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return ( + gr.update(value=prompts["user_prompt"], visible=True), + gr.update(value=prompts["system_prompt"], visible=True) + ) + + def on_custom_prompt_checkbox_change(is_checked): + return ( + gr.update(visible=is_checked), + gr.update(visible=is_checked) + ) + + custom_prompt_checkbox.change( + fn=on_custom_prompt_checkbox_change, + inputs=[custom_prompt_checkbox], + outputs=[user_prompt, system_prompt] + ) + + def clear_all_chats(): + return [[]] * 3 + [[]] * 3 + [0] * 3 + + clear_chat_button.click( + clear_all_chats, + outputs=chatbots + chat_history + token_count_displays + ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt.change(update_user_prompt, inputs=preset_prompt, outputs=user_prompt) + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, + total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected + preset_prompt.change( + update_prompts, + inputs=[preset_prompt], + outputs=[user_prompt, system_prompt] + ) def clear_all_chats(): - return [[]] * 3 + [[]] * 3 + return [[]] * 3 + [[]] * 3 + [0] * 3 clear_chat_button.click( clear_all_chats, - outputs=chatbots + chat_history + outputs=chatbots + chat_history + token_count_displays ) + def chat_wrapper_multi(message, custom_prompt, system_prompt, *args): chat_histories = args[:3] chatbots = args[3:6] @@ -717,6 +1142,11 @@ def chat_wrapper_multi(message, custom_prompt, system_prompt, *args): return [gr.update(value="")] + new_chatbots + new_chat_histories + def update_token_counts(*histories): + 
token_counts = [] + for history in histories: + token_counts.append(approximate_token_count(history)) + return token_counts def regenerate_last_message(chat_history, chatbot, media_content, selected_parts, api_endpoint, api_key, custom_prompt, temperature, system_prompt): if not chat_history: @@ -753,8 +1183,13 @@ def regenerate_last_message(chat_history, chatbot, media_content, selected_parts for i in range(3): regenerate_buttons[i].click( regenerate_last_message, - inputs=[chat_history[i], chatbots[i], media_content, selected_parts, api_endpoints[i], api_keys[i], user_prompt, temperatures[i], system_prompt], + inputs=[chat_history[i], chatbots[i], media_content, selected_parts, api_endpoints[i], api_keys[i], + user_prompt, temperatures[i], system_prompt], outputs=[chatbots[i], chat_history[i], gr.Textbox(label=f"Regenerate Status {i + 1}")] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history[i]], + outputs=[token_count_displays[i]] ) # In the create_chat_interface_multi_api function: @@ -767,6 +1202,10 @@ def regenerate_last_message(chat_history, chatbot, media_content, selected_parts ).then( lambda: (gr.update(value=""), gr.update(value="")), outputs=[msg, user_prompt] + ).then( + update_token_counts, + inputs=chat_history, + outputs=token_count_displays ) items_output.change( @@ -783,7 +1222,6 @@ def regenerate_last_message(chat_history, chatbot, media_content, selected_parts ) - def create_chat_interface_four(): try: default_value = None @@ -808,17 +1246,32 @@ def create_chat_interface_four(): with gr.TabItem("Four Independent API Chats", visible=True): gr.Markdown("# Four Independent API Chat Interfaces") + # Initialize prompts during component creation + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + current_page_state = gr.State(value=current_page) + total_pages_state = gr.State(value=total_pages) + page_display_text = f"Page {current_page} of {total_pages}" + with gr.Row(): with gr.Column(): preset_prompt = gr.Dropdown( - label="Select Preset Prompt", - choices=load_preset_prompts(), + label="Select Preset Prompt (This will be prefixed to your messages, recommend copy/pasting and then clearing the User Prompt box)", + choices=prompts, visible=True ) + prev_page_button = gr.Button("Previous Page", visible=True) + page_display = gr.Markdown(page_display_text, visible=True) + next_page_button = gr.Button("Next Page", visible=True) user_prompt = gr.Textbox( - label="Modify Prompt", + label="Modify User Prompt", + lines=3 + ) + system_prompt = gr.Textbox( + label="System Prompt", + value="You are a helpful AI assistant.", lines=3 ) + with gr.Column(): gr.Markdown("Scroll down for the chat windows...") @@ -848,6 +1301,8 @@ def create_single_chat_interface(index, user_prompt_component): msg = gr.Textbox(label=f"Enter your message for Chat {index + 1}") submit = gr.Button(f"Submit to Chat {index + 1}") regenerate_button = gr.Button(f"Regenerate Last Message {index + 1}") + token_count_display = gr.Number(label=f"Approximate Token Count {index + 1}", value=0, + interactive=False) clear_chat_button = gr.Button(f"Clear Chat {index + 1}") # State to maintain chat history @@ -863,7 +1318,8 @@ def create_single_chat_interface(index, user_prompt_component): 'submit': submit, 'regenerate_button': regenerate_button, 'clear_chat_button': clear_chat_button, - 'chat_history': chat_history + 'chat_history': chat_history, + 'token_count_display': token_count_display }) # Create four chat interfaces arranged in a 2x2 grid @@ -874,10 +1330,47 @@ 
def create_single_chat_interface(index, user_prompt_component): create_single_chat_interface(i * 2 + j, user_prompt) # Update user_prompt based on preset_prompt selection + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return gr.update(value=prompts["user_prompt"]), gr.update(value=prompts["system_prompt"]) + preset_prompt.change( - fn=update_user_prompt, - inputs=preset_prompt, - outputs=user_prompt + fn=update_prompts, + inputs=[preset_prompt], + outputs=[user_prompt, system_prompt] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) def chat_wrapper_single(message, chat_history, api_endpoint, api_key, temperature, user_prompt): @@ -957,6 +1450,10 @@ def regenerate_last_message(chat_history, api_endpoint, api_key, temperature, us interface['chatbot'], interface['chat_history'] ] + ).then( + lambda history: approximate_token_count(history), + inputs=[interface['chat_history']], + outputs=[interface['token_count_display']] ) interface['regenerate_button'].click( @@ -973,12 +1470,18 @@ def regenerate_last_message(chat_history, api_endpoint, api_key, temperature, us interface['chat_history'], gr.Textbox(label="Regenerate Status") ] + ).then( + lambda history: approximate_token_count(history), + inputs=[interface['chat_history']], + outputs=[interface['token_count_display']] ) + def clear_chat_single(): + return [], [], 0 + interface['clear_chat_button'].click( clear_chat_single, - inputs=[], - outputs=[interface['chatbot'], interface['chat_history']] + outputs=[interface['chatbot'], interface['chat_history'], interface['token_count_display']] ) @@ -997,233 +1500,11 @@ def chat_wrapper_single(message, chat_history, chatbot, api_endpoint, api_key, t return new_msg, updated_chatbot, new_history, new_conv_id - -# FIXME - Finish implementing functions + testing/valdidation -def create_chat_management_tab(): - with gr.TabItem("Chat Management", visible=True): - gr.Markdown("# Chat Management") - - with gr.Row(): - search_query = gr.Textbox(label="Search Conversations") - search_button = gr.Button("Search") - - conversation_list = gr.Dropdown(label="Select Conversation", choices=[]) - conversation_mapping = gr.State({}) - - with gr.Tabs(): - with gr.TabItem("Edit", visible=True): - chat_content = gr.TextArea(label="Chat Content (JSON)", lines=20, max_lines=50) - save_button = gr.Button("Save Changes") - delete_button = gr.Button("Delete Conversation", variant="stop") - - with gr.TabItem("Preview", visible=True): - chat_preview = gr.HTML(label="Chat Preview") - result_message = gr.Markdown("") 
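Each chat event above refreshes its token display by chaining `.then(lambda history: approximate_token_count(history), ...)` after the main handler. The body of `approximate_token_count` is not shown in this diff, so the sketch below is only a plausible stand-in, assuming the common ~4-characters-per-token heuristic:

```python
from typing import List, Optional, Tuple

def approximate_token_count_sketch(
        history: List[Tuple[Optional[str], Optional[str]]]) -> int:
    """Rough token estimate for a Gradio chat history.

    Assumes ~4 characters per token, a common heuristic for English text;
    the real approximate_token_count() in Chat_Functions.py may differ.
    """
    total_chars = sum(len(part) for pair in history for part in pair if part)
    return total_chars // 4

# Wired the same way as the .then() chains above:
history = [("Hello there", None), (None, "Hi! How can I help?")]
print(approximate_token_count_sketch(history))  # -> 7
```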
-
-    def search_conversations(query):
-        conversations = search_chat_conversations(query)
-        choices = [f"{conv['conversation_name']} (Media: {conv['media_title']}, ID: {conv['id']})" for conv in
-                   conversations]
-        mapping = {choice: conv['id'] for choice, conv in zip(choices, conversations)}
-        return gr.update(choices=choices), mapping
-
-    def load_conversations(selected, conversation_mapping):
-        logging.info(f"Selected: {selected}")
-        logging.info(f"Conversation mapping: {conversation_mapping}")
-
-        try:
-            if selected and selected in conversation_mapping:
-                conversation_id = conversation_mapping[selected]
-                messages = get_chat_messages(conversation_id)
-                conversation_data = {
-                    "conversation_id": conversation_id,
-                    "messages": messages
-                }
-                json_content = json.dumps(conversation_data, indent=2)
-
-                # Create HTML preview
-                html_preview = "<div>"
-                for msg in messages:
-                    sender_style = "background-color: #e6f3ff;" if msg[
-                        'sender'] == 'user' else "background-color: #f0f0f0;"
-                    html_preview += f"<div style='{sender_style}'>"
-                    html_preview += f"{msg['sender']}: {html.escape(msg['message'])}<br>"
-                    html_preview += f"Timestamp: {msg['timestamp']}"
-                    html_preview += "</div>"
-                html_preview += "</div>"
-
-                logging.info("Returning json_content and html_preview")
-                return json_content, html_preview
-            else:
-                logging.warning("No conversation selected or not in mapping")
-                return "", "<div><p>No conversation selected</p></div>"
-        except Exception as e:
-            logging.error(f"Error in load_conversations: {str(e)}")
-            return f"Error: {str(e)}", "<div><p>Error loading conversation</p></div>"
-
-    def validate_conversation_json(content):
-        try:
-            data = json.loads(content)
-            if not isinstance(data, dict):
-                return False, "Invalid JSON structure: root should be an object"
-            if "conversation_id" not in data or not isinstance(data["conversation_id"], int):
-                return False, "Missing or invalid conversation_id"
-            if "messages" not in data or not isinstance(data["messages"], list):
-                return False, "Missing or invalid messages array"
-            for msg in data["messages"]:
-                if not all(key in msg for key in ["sender", "message"]):
-                    return False, "Invalid message structure: missing required fields"
-            return True, data
-        except json.JSONDecodeError as e:
-            return False, f"Invalid JSON: {str(e)}"
-
-    def save_conversation(selected, conversation_mapping, content):
-        if not selected or selected not in conversation_mapping:
-            return "Please select a conversation before saving.", "<div><p>No changes made</p></div>"
-
-        conversation_id = conversation_mapping[selected]
-        is_valid, result = validate_conversation_json(content)
-
-        if not is_valid:
-            return f"Error: {result}", "<div><p>No changes made due to error</p></div>"
-
-        conversation_data = result
-        if conversation_data["conversation_id"] != conversation_id:
-            return "Error: Conversation ID mismatch.", "<div><p>No changes made due to ID mismatch</p></div>"
-
-        try:
-            with db.get_connection() as conn:
-                conn.execute("BEGIN TRANSACTION")
-                cursor = conn.cursor()
-
-                # Backup original conversation
-                cursor.execute("SELECT * FROM ChatMessages WHERE conversation_id = ?", (conversation_id,))
-                original_messages = cursor.fetchall()
-                backup_data = json.dumps({"conversation_id": conversation_id, "messages": original_messages})
-
-                # You might want to save this backup_data somewhere
-
-                # Delete existing messages
-                cursor.execute("DELETE FROM ChatMessages WHERE conversation_id = ?", (conversation_id,))
-
-                # Insert updated messages
-                for message in conversation_data["messages"]:
-                    cursor.execute('''
-                    INSERT INTO ChatMessages (conversation_id, sender, message, timestamp)
-                    VALUES (?, ?, ?, COALESCE(?, CURRENT_TIMESTAMP))
-                    ''', (conversation_id, message["sender"], message["message"], message.get("timestamp")))
-
-                conn.commit()
-
-                # Create updated HTML preview
-                html_preview = "<div>"
-                for msg in conversation_data["messages"]:
-                    sender_style = "background-color: #e6f3ff;" if msg[
-                        'sender'] == 'user' else "background-color: #f0f0f0;"
-                    html_preview += f"<div style='{sender_style}'>"
-                    html_preview += f"{msg['sender']}: {html.escape(msg['message'])}<br>"
-                    html_preview += f"Timestamp: {msg.get('timestamp', 'N/A')}"
-                    html_preview += "</div>"
-                html_preview += "</div>"
-
-                return "Conversation updated successfully.", html_preview
-        except sqlite3.Error as e:
-            conn.rollback()
-            logging.error(f"Database error in save_conversation: {e}")
-            return f"Error updating conversation: {str(e)}", "<div><p>Error occurred while saving</p></div>"
-        except Exception as e:
-            conn.rollback()
-            logging.error(f"Unexpected error in save_conversation: {e}")
-            return f"Unexpected error: {str(e)}", "<div><p>Unexpected error occurred</p></div>"
-
-    def delete_conversation(selected, conversation_mapping):
-        if not selected or selected not in conversation_mapping:
-            return "Please select a conversation before deleting.", "<div><p>No changes made</p></div>", gr.update(choices=[])
-
-        conversation_id = conversation_mapping[selected]
-
-        try:
-            with db.get_connection() as conn:
-                cursor = conn.cursor()
-
-                # Delete messages associated with the conversation
-                cursor.execute("DELETE FROM ChatMessages WHERE conversation_id = ?", (conversation_id,))
-
-                # Delete the conversation itself
-                cursor.execute("DELETE FROM ChatConversations WHERE id = ?", (conversation_id,))
-
-                conn.commit()
-
-                # Update the conversation list
-                remaining_conversations = [choice for choice in conversation_mapping.keys() if choice != selected]
-                updated_mapping = {choice: conversation_mapping[choice] for choice in remaining_conversations}
-
-                return "Conversation deleted successfully.", "<div><p>Conversation deleted</p></div>", gr.update(choices=remaining_conversations)
-        except sqlite3.Error as e:
-            conn.rollback()
-            logging.error(f"Database error in delete_conversation: {e}")
-            return f"Error deleting conversation: {str(e)}", "<div><p>Error occurred while deleting</p></div>", gr.update()
-        except Exception as e:
-            conn.rollback()
-            logging.error(f"Unexpected error in delete_conversation: {e}")
-            return f"Unexpected error: {str(e)}", "<div><p>Unexpected error occurred</p></div>", gr.update()
-
-    def parse_formatted_content(formatted_content):
-        lines = formatted_content.split('\n')
-        conversation_id = int(lines[0].split(': ')[1])
-        timestamp = lines[1].split(': ')[1]
-        history = []
-        current_role = None
-        current_content = None
-        for line in lines[3:]:
-            if line.startswith("Role: "):
-                if current_role is not None:
-                    history.append({"role": current_role, "content": ["", current_content]})
-                current_role = line.split(': ')[1]
-            elif line.startswith("Content: "):
-                current_content = line.split(': ', 1)[1]
-        if current_role is not None:
-            history.append({"role": current_role, "content": ["", current_content]})
-        return json.dumps({
-            "conversation_id": conversation_id,
-            "timestamp": timestamp,
-            "history": history
-        }, indent=2)
-
-    search_button.click(
-        search_conversations,
-        inputs=[search_query],
-        outputs=[conversation_list, conversation_mapping]
-    )
-
-    conversation_list.change(
-        load_conversations,
-        inputs=[conversation_list, conversation_mapping],
-        outputs=[chat_content, chat_preview]
-    )
-
-    save_button.click(
-        save_conversation,
-        inputs=[conversation_list, conversation_mapping, chat_content],
-        outputs=[result_message, chat_preview]
-    )
-
-    delete_button.click(
-        delete_conversation,
-        inputs=[conversation_list, conversation_mapping],
-        outputs=[result_message, chat_preview, conversation_list]
-    )
-
-    return search_query, search_button, conversation_list, conversation_mapping, chat_content, save_button, delete_button, result_message, chat_preview
-
-
# Mock function to simulate LLM processing
def process_with_llm(workflow, context, prompt, api_endpoint, api_key):
    api_key_snippet = api_key[:5] + "..." if api_key else "Not provided"
    return f"LLM output using {api_endpoint} (API Key: {api_key_snippet}) for {workflow} with context: {context[:30]}... and prompt: {prompt[:30]}..."
- # # End of Chat_ui.py ####################################################################################################################### \ No newline at end of file diff --git a/App_Function_Libraries/Gradio_UI/Embeddings_tab.py b/App_Function_Libraries/Gradio_UI/Embeddings_tab.py index a12e10320..fb6482ee2 100644 --- a/App_Function_Libraries/Gradio_UI/Embeddings_tab.py +++ b/App_Function_Libraries/Gradio_UI/Embeddings_tab.py @@ -4,6 +4,7 @@ # Imports import json import logging +import os # # External Imports import gradio as gr @@ -11,26 +12,58 @@ from tqdm import tqdm # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_all_content_from_database +from App_Function_Libraries.DB.DB_Manager import get_all_content_from_database, get_all_conversations, \ + get_conversation_text, get_note_by_id +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_all_notes from App_Function_Libraries.RAG.ChromaDB_Library import chroma_client, \ store_in_chroma, situate_context from App_Function_Libraries.RAG.Embeddings_Create import create_embedding, create_embeddings_batch from App_Function_Libraries.Chunk_Lib import improved_chunking_process, chunk_for_embedding +from App_Function_Libraries.Utils.Utils import load_and_log_configs + + # ######################################################################################################################## # # Functions: def create_embeddings_tab(): + # Load configuration first + config = load_and_log_configs() + if not config: + raise ValueError("Could not load configuration") + + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + rag_qa_db_path = os.path.join(os.path.dirname(media_db_path), "rag_qa.db") + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "chatDB.db") + chroma_db_path = db_config['chroma_db_path'] + with gr.TabItem("Create Embeddings", visible=True): gr.Markdown("# Create Embeddings for All Content") with gr.Row(): with gr.Column(): + # Database selection at the top + database_selection = gr.Radio( + choices=["Media DB", "RAG Chat", "Character Chat"], + label="Select Content Source", + value="Media DB", + info="Choose which database to create embeddings from" + ) + + # Add database path display + current_db_path = gr.Textbox( + label="Current Database Path", + value=media_db_path, + interactive=False + ) + embedding_provider = gr.Radio( choices=["huggingface", "local", "openai"], label="Select Embedding Provider", - value="huggingface" + value=config['embedding_config']['embedding_provider'] or "huggingface" ) gr.Markdown("Note: Local provider requires a running Llama.cpp/llamafile server.") gr.Markdown("OpenAI provider requires a valid API key.") @@ -65,22 +98,24 @@ def create_embeddings_tab(): embedding_api_url = gr.Textbox( label="API URL (for local provider)", - value="http://localhost:8080/embedding", + value=config['embedding_config']['embedding_api_url'], visible=False ) - # Add chunking options + # Add chunking options with config defaults chunking_method = gr.Dropdown( choices=["words", "sentences", "paragraphs", "tokens", "semantic"], label="Chunking Method", value="words" ) max_chunk_size = gr.Slider( - minimum=1, maximum=8000, step=1, value=500, + minimum=1, maximum=8000, step=1, + value=config['embedding_config']['chunk_size'], label="Max Chunk Size" ) chunk_overlap = gr.Slider( - minimum=0, maximum=4000, step=1, value=200, + minimum=0, maximum=4000, step=1, + value=config['embedding_config']['overlap'], label="Chunk 
Overlap" ) adaptive_chunking = gr.Checkbox( @@ -92,6 +127,7 @@ def create_embeddings_tab(): with gr.Column(): status_output = gr.Textbox(label="Status", lines=10) + progress = gr.Progress() def update_provider_options(provider): if provider == "huggingface": @@ -107,23 +143,54 @@ def update_huggingface_options(model): else: return gr.update(visible=False) - embedding_provider.change( - fn=update_provider_options, - inputs=[embedding_provider], - outputs=[huggingface_model, openai_model, custom_embedding_model, embedding_api_url] - ) - - huggingface_model.change( - fn=update_huggingface_options, - inputs=[huggingface_model], - outputs=[custom_embedding_model] - ) + def update_database_path(database_type): + if database_type == "Media DB": + return media_db_path + elif database_type == "RAG Chat": + return rag_qa_db_path + else: # Character Chat + return character_chat_db_path - def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_url, method, max_size, overlap, adaptive): + def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_url, method, + max_size, overlap, adaptive, database_type, progress=gr.Progress()): try: - all_content = get_all_content_from_database() + # Initialize content based on database selection + if database_type == "Media DB": + all_content = get_all_content_from_database() + content_type = "media" + elif database_type == "RAG Chat": + all_content = [] + page = 1 + while True: + conversations, total_pages, _ = get_all_conversations(page=page) + if not conversations: + break + all_content.extend([{ + 'id': conv['conversation_id'], + 'content': get_conversation_text(conv['conversation_id']), + 'title': conv['title'], + 'type': 'conversation' + } for conv in conversations]) + progress(page / total_pages, desc=f"Loading conversations... Page {page}/{total_pages}") + page += 1 + else: # Character Chat + all_content = [] + page = 1 + while True: + notes, total_pages, _ = get_all_notes(page=page) + if not notes: + break + all_content.extend([{ + 'id': note['id'], + 'content': f"{note['title']}\n\n{note['content']}", + 'conversation_id': note['conversation_id'], + 'type': 'note' + } for note in notes]) + progress(page / total_pages, desc=f"Loading notes... Page {page}/{total_pages}") + page += 1 + if not all_content: - return "No content found in the database." + return "No content found in the selected database." 
chunk_options = { 'method': method, @@ -132,7 +199,7 @@ def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_ur 'adaptive': adaptive } - collection_name = "all_content_embeddings" + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" collection = chroma_client.get_or_create_collection(name=collection_name) # Determine the model to use @@ -141,55 +208,113 @@ def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_ur elif provider == "openai": model = openai_model else: - model = custom_model + model = api_url + + total_items = len(all_content) + for idx, item in enumerate(all_content): + progress((idx + 1) / total_items, desc=f"Processing item {idx + 1} of {total_items}") - for item in all_content: - media_id = item['id'] + content_id = item['id'] text = item['content'] chunks = improved_chunking_process(text, chunk_options) - for i, chunk in enumerate(chunks): + for chunk_idx, chunk in enumerate(chunks): chunk_text = chunk['text'] - chunk_id = f"doc_{media_id}_chunk_{i}" - - existing = collection.get(ids=[chunk_id]) - if existing['ids']: + chunk_id = f"{database_type.lower()}_{content_id}_chunk_{chunk_idx}" + + try: + embedding = create_embedding(chunk_text, provider, model, api_url) + metadata = { + 'content_id': str(content_id), + 'chunk_index': int(chunk_idx), + 'total_chunks': int(len(chunks)), + 'chunking_method': method, + 'max_chunk_size': int(max_size), + 'chunk_overlap': int(overlap), + 'adaptive_chunking': bool(adaptive), + 'embedding_model': model, + 'embedding_provider': provider, + 'content_type': item.get('type', 'media'), + 'conversation_id': item.get('conversation_id'), + **{k: (int(v) if isinstance(v, str) and v.isdigit() else v) + for k, v in chunk['metadata'].items()} + } + store_in_chroma(collection_name, [chunk_text], [embedding], [chunk_id], [metadata]) + + except Exception as e: + logging.error(f"Error processing chunk {chunk_id}: {str(e)}") continue - embedding = create_embedding(chunk_text, provider, model, api_url) - metadata = { - "media_id": str(media_id), - "chunk_index": i, - "total_chunks": len(chunks), - "chunking_method": method, - "max_chunk_size": max_size, - "chunk_overlap": overlap, - "adaptive_chunking": adaptive, - "embedding_model": model, - "embedding_provider": provider, - **chunk['metadata'] - } - store_in_chroma(collection_name, [chunk_text], [embedding], [chunk_id], [metadata]) - - return "Embeddings created and stored successfully for all content." + return f"Embeddings created and stored successfully for all {database_type} content." 
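+            # Note that chunk-level failures are caught inside the loop above (logged,
+            # then `continue`), so one bad chunk skips only itself; the handler below
+            # is reached only by errors raised outside the per-chunk loop.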
except Exception as e: logging.error(f"Error during embedding creation: {str(e)}") return f"Error: {str(e)}" + # Event handlers + embedding_provider.change( + fn=update_provider_options, + inputs=[embedding_provider], + outputs=[huggingface_model, openai_model, custom_embedding_model, embedding_api_url] + ) + + huggingface_model.change( + fn=update_huggingface_options, + inputs=[huggingface_model], + outputs=[custom_embedding_model] + ) + + database_selection.change( + fn=update_database_path, + inputs=[database_selection], + outputs=[current_db_path] + ) + create_button.click( fn=create_all_embeddings, - inputs=[embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, - chunking_method, max_chunk_size, chunk_overlap, adaptive_chunking], + inputs=[ + embedding_provider, huggingface_model, openai_model, custom_embedding_model, + embedding_api_url, chunking_method, max_chunk_size, chunk_overlap, + adaptive_chunking, database_selection + ], outputs=status_output ) def create_view_embeddings_tab(): + # Load configuration first + config = load_and_log_configs() + if not config: + raise ValueError("Could not load configuration") + + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + rag_qa_db_path = os.path.join(os.path.dirname(media_db_path), "rag_chat.db") + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "character_chat.db") + chroma_db_path = db_config['chroma_db_path'] + with gr.TabItem("View/Update Embeddings", visible=True): gr.Markdown("# View and Update Embeddings") - item_mapping = gr.State({}) + # Initialize item_mapping as a Gradio State + + with gr.Row(): with gr.Column(): + # Add database selection + database_selection = gr.Radio( + choices=["Media DB", "RAG Chat", "Character Chat"], + label="Select Content Source", + value="Media DB", + info="Choose which database to view embeddings from" + ) + + # Add database path display + current_db_path = gr.Textbox( + label="Current Database Path", + value=media_db_path, + interactive=False + ) + item_dropdown = gr.Dropdown(label="Select Item", choices=[], interactive=True) refresh_button = gr.Button("Refresh Item List") embedding_status = gr.Textbox(label="Embedding Status", interactive=False) @@ -236,9 +361,10 @@ def create_view_embeddings_tab(): embedding_api_url = gr.Textbox( label="API URL (for local provider)", - value="http://localhost:8080/embedding", + value=config['embedding_config']['embedding_api_url'], visible=False ) + chunking_method = gr.Dropdown( choices=["words", "sentences", "paragraphs", "tokens", "semantic"], label="Chunking Method", @@ -267,15 +393,45 @@ def create_view_embeddings_tab(): ) contextual_api_key = gr.Textbox(label="API Key", lines=1) - def get_items_with_embedding_status(): + item_mapping = gr.State(value={}) + + def update_database_path(database_type): + if database_type == "Media DB": + return media_db_path + elif database_type == "RAG Chat": + return rag_qa_db_path + else: # Character Chat + return character_chat_db_path + + def get_items_with_embedding_status(database_type): try: - items = get_all_content_from_database() - collection = chroma_client.get_or_create_collection(name="all_content_embeddings") + # Get items based on database selection + if database_type == "Media DB": + items = get_all_content_from_database() + elif database_type == "RAG Chat": + conversations, _, _ = get_all_conversations(page=1) + items = [{ + 'id': conv['conversation_id'], + 'title': conv['title'], + 
'type': 'conversation' + } for conv in conversations] + else: # Character Chat + notes, _, _ = get_all_notes(page=1) + items = [{ + 'id': note['id'], + 'title': note['title'], + 'type': 'note' + } for note in notes] + + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" + collection = chroma_client.get_or_create_collection(name=collection_name) + choices = [] new_item_mapping = {} for item in items: try: - result = collection.get(ids=[f"doc_{item['id']}_chunk_0"]) + chunk_id = f"{database_type.lower()}_{item['id']}_chunk_0" + result = collection.get(ids=[chunk_id]) embedding_exists = result is not None and result.get('ids') and len(result['ids']) > 0 status = "Embedding exists" if embedding_exists else "No embedding" except Exception as e: @@ -303,40 +459,62 @@ def update_huggingface_options(model): else: return gr.update(visible=False) - def check_embedding_status(selected_item, item_mapping): + def check_embedding_status(selected_item, database_type, item_mapping): if not selected_item: return "Please select an item", "", "" + if item_mapping is None: + # If mapping is None, try to refresh it + try: + _, item_mapping = get_items_with_embedding_status(database_type) + except Exception as e: + return f"Error initializing item mapping: {str(e)}", "", "" + try: item_id = item_mapping.get(selected_item) if item_id is None: return f"Invalid item selected: {selected_item}", "", "" item_title = selected_item.rsplit(' (', 1)[0] - collection = chroma_client.get_or_create_collection(name="all_content_embeddings") + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" + collection = chroma_client.get_or_create_collection(name=collection_name) + chunk_id = f"{database_type.lower()}_{item_id}_chunk_0" + + try: + result = collection.get(ids=[chunk_id], include=["embeddings", "metadatas"]) + except Exception as e: + logging.error(f"ChromaDB get error: {str(e)}") + return f"Error retrieving embedding for '{item_title}': {str(e)}", "", "" - result = collection.get(ids=[f"doc_{item_id}_chunk_0"], include=["embeddings", "metadatas"]) - logging.info(f"ChromaDB result for item '{item_title}' (ID: {item_id}): {result}") + # Check if result exists and has the expected structure + if not result or not isinstance(result, dict): + return f"No embedding found for item '{item_title}' (ID: {item_id})", "", "" - if not result['ids']: + # Check if we have any results + if not result.get('ids') or len(result['ids']) == 0: return f"No embedding found for item '{item_title}' (ID: {item_id})", "", "" - if not result['embeddings'] or not result['embeddings'][0]: + # Check if embeddings exist + if not result.get('embeddings') or not result['embeddings'][0]: return f"Embedding data missing for item '{item_title}' (ID: {item_id})", "", "" embedding = result['embeddings'][0] - metadata = result['metadatas'][0] if result['metadatas'] else {} + metadata = result.get('metadatas', [{}])[0] if result.get('metadatas') else {} embedding_preview = str(embedding[:50]) status = f"Embedding exists for item '{item_title}' (ID: {item_id})" return status, f"First 50 elements of embedding:\n{embedding_preview}", json.dumps(metadata, indent=2) except Exception as e: - logging.error(f"Error in check_embedding_status: {str(e)}") + logging.error(f"Error in check_embedding_status: {str(e)}", exc_info=True) return f"Error processing item: {selected_item}. 
Details: {str(e)}", "", "" - def create_new_embedding_for_item(selected_item, provider, hf_model, openai_model, custom_model, api_url, - method, max_size, overlap, adaptive, - item_mapping, use_contextual, contextual_api_choice=None): + def refresh_and_update(database_type): + choices_update, new_mapping = get_items_with_embedding_status(database_type) + return choices_update, new_mapping + + def create_new_embedding_for_item(selected_item, database_type, provider, hf_model, openai_model, + custom_model, api_url, method, max_size, overlap, adaptive, + item_mapping, use_contextual, contextual_api_choice=None): if not selected_item: return "Please select an item", "", "" @@ -345,8 +523,26 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode if item_id is None: return f"Invalid item selected: {selected_item}", "", "" - items = get_all_content_from_database() - item = next((item for item in items if item['id'] == item_id), None) + # Get item content based on database type + if database_type == "Media DB": + items = get_all_content_from_database() + item = next((item for item in items if item['id'] == item_id), None) + elif database_type == "RAG Chat": + item = { + 'id': item_id, + 'content': get_conversation_text(item_id), + 'title': selected_item.rsplit(' (', 1)[0], + 'type': 'conversation' + } + else: # Character Chat + note = get_note_by_id(item_id) + item = { + 'id': item_id, + 'content': f"{note['title']}\n\n{note['content']}", + 'title': note['title'], + 'type': 'note' + } + if not item: return f"Item not found: {item_id}", "", "" @@ -359,11 +555,11 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode logging.info(f"Chunking content for item: {item['title']} (ID: {item_id})") chunks = chunk_for_embedding(item['content'], item['title'], chunk_options) - collection_name = "all_content_embeddings" + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" collection = chroma_client.get_or_create_collection(name=collection_name) # Delete existing embeddings for this item - existing_ids = [f"doc_{item_id}_chunk_{i}" for i in range(len(chunks))] + existing_ids = [f"{database_type.lower()}_{item_id}_chunk_{i}" for i in range(len(chunks))] collection.delete(ids=existing_ids) logging.info(f"Deleted {len(existing_ids)} existing embeddings for item {item_id}") @@ -381,7 +577,7 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode contextualized_text = chunk_text context = None - chunk_id = f"doc_{item_id}_chunk_{i}" + chunk_id = f"{database_type.lower()}_{item_id}_chunk_{i}" # Determine the model to use if provider == "huggingface": @@ -392,7 +588,7 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode model = custom_model metadata = { - "media_id": str(item_id), + "content_id": str(item_id), "chunk_index": i, "total_chunks": len(chunks), "chunking_method": method, @@ -441,15 +637,25 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode logging.error(f"Error in create_new_embedding_for_item: {str(e)}", exc_info=True) return f"Error creating embedding: {str(e)}", "", "" + # Wire up all the event handlers + database_selection.change( + update_database_path, + inputs=[database_selection], + outputs=[current_db_path] + ) + refresh_button.click( get_items_with_embedding_status, + inputs=[database_selection], outputs=[item_dropdown, item_mapping] ) + item_dropdown.change( check_embedding_status, - inputs=[item_dropdown, item_mapping], + 
inputs=[item_dropdown, database_selection, item_mapping], outputs=[embedding_status, embedding_preview, embedding_metadata] ) + create_new_embedding_button.click( create_new_embedding_for_item, inputs=[item_dropdown, embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, @@ -469,9 +675,10 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode ) return (item_dropdown, refresh_button, embedding_status, embedding_preview, embedding_metadata, - create_new_embedding_button, embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, - chunking_method, max_chunk_size, chunk_overlap, adaptive_chunking, - use_contextual_embeddings, contextual_api_choice, contextual_api_key) + create_new_embedding_button, embedding_provider, huggingface_model, openai_model, + custom_embedding_model, embedding_api_url, chunking_method, max_chunk_size, + chunk_overlap, adaptive_chunking, use_contextual_embeddings, + contextual_api_choice, contextual_api_key) def create_purge_embeddings_tab(): diff --git a/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py b/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py index fbf5505ac..50942fa91 100644 --- a/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py +++ b/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py @@ -7,7 +7,7 @@ # External Imports import gradio as gr -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_user_prompt # # Local Imports @@ -37,32 +37,52 @@ def create_summarize_explain_tab(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + with gr.TabItem("Analyze Text", visible=True): gr.Markdown("# Analyze / Explain / Summarize Text without ingesting it into the DB") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): with gr.Row(): - text_to_work_input = gr.Textbox(label="Text to be Explained or Summarized", - placeholder="Enter the text you want explained or summarized here", - lines=20) + text_to_work_input = gr.Textbox( + label="Text to be Explained or Summarized", + placeholder="Enter the text you want explained or summarized here", + lines=20 + ) with gr.Row(): explanation_checkbox = gr.Checkbox(label="Explain Text", value=True) summarization_checkbox = gr.Checkbox(label="Summarize Text", value=True) - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - 
placeholder="Enter custom prompt here", - lines=3, - visible=False) + custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=10, + visible=False + ) with gr.Row(): system_prompt_input = gr.Textbox(label="System Prompt", value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] @@ -82,7 +102,7 @@ def create_summarize_explain_tab(): - Ensure adherence to specified format - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] """, - lines=3, + lines=10, visible=False, interactive=True) # Refactored API selection dropdown @@ -92,8 +112,11 @@ def create_summarize_explain_tab(): label="API for Summarization/Analysis (Optional)" ) with gr.Row(): - api_key_input = gr.Textbox(label="API Key (if required)", placeholder="Enter your API key here", - type="password") + api_key_input = gr.Textbox( + label="API Key (if required)", + placeholder="Enter your API key here", + type="password" + ) with gr.Row(): explain_summarize_button = gr.Button("Explain/Summarize") @@ -102,17 +125,83 @@ def create_summarize_explain_tab(): explanation_output = gr.Textbox(label="Explanation:", lines=20) custom_prompt_output = gr.Textbox(label="Custom Prompt:", lines=20, visible=True) + # Handle custom prompt checkbox change custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[custom_prompt_checkbox], outputs=[custom_prompt_input, system_prompt_input] ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[ + preset_prompt, + prev_page_button, + next_page_button, + page_display, + current_page_state, + total_pages_state + ] ) + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = 
list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -121,18 +210,27 @@ def update_prompts(preset_name): ) preset_prompt.change( - update_prompts, - inputs=preset_prompt, + fn=update_prompts, + inputs=[preset_prompt], outputs=[custom_prompt_input, system_prompt_input] ) explain_summarize_button.click( fn=summarize_explain_text, - inputs=[text_to_work_input, api_endpoint, api_key_input, summarization_checkbox, explanation_checkbox, custom_prompt_input, system_prompt_input], + inputs=[ + text_to_work_input, + api_endpoint, + api_key_input, + summarization_checkbox, + explanation_checkbox, + custom_prompt_input, + system_prompt_input + ], outputs=[summarization_output, explanation_output, custom_prompt_output] ) + def summarize_explain_text(message, api_endpoint, api_key, summarization, explanation, custom_prompt, custom_system_prompt,): global custom_prompt_output summarization_response = None diff --git a/App_Function_Libraries/Gradio_UI/Export_Functionality.py b/App_Function_Libraries/Gradio_UI/Export_Functionality.py index 540f19ac5..806538585 100644 --- a/App_Function_Libraries/Gradio_UI/Export_Functionality.py +++ b/App_Function_Libraries/Gradio_UI/Export_Functionality.py @@ -6,9 +6,11 @@ import logging import shutil import tempfile -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Tuple, Any import gradio as gr -from App_Function_Libraries.DB.DB_Manager import DatabaseError +from App_Function_Libraries.DB.DB_Manager import DatabaseError, fetch_all_notes, fetch_all_conversations, \ + get_keywords_for_note, fetch_notes_by_ids, fetch_conversations_by_ids +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_keywords_for_conversation from App_Function_Libraries.Gradio_UI.Gradio_Shared import fetch_item_details, fetch_items_by_keyword, browse_items logger = logging.getLogger(__name__) @@ -36,7 +38,7 @@ def export_items_by_keyword(keyword: str) -> str: items = fetch_items_by_keyword(keyword) if not items: logger.warning(f"No items found for keyword: {keyword}") - return None + return f"No items found for keyword: {keyword}" # Create a temporary directory to store individual markdown files with tempfile.TemporaryDirectory() as temp_dir: @@ -66,7 +68,7 @@ def export_items_by_keyword(keyword: str) -> str: return final_zip_path except Exception as e: logger.error(f"Error exporting items for keyword '{keyword}': {str(e)}") - return None + return f"Error exporting items for keyword '{keyword}': {str(e)}" def export_selected_items(selected_items: List[Dict]) -> Tuple[Optional[str], str]: @@ -146,121 +148,747 @@ def 
display_search_results_export_tab(search_query: str, search_type: str, page: logger.error(error_message) return [], error_message, 1, 1 +# +# End of Media DB Export functionality +################################################################ -def create_export_tab(): - with gr.Tab("Search and Export"): - with gr.Row(): - with gr.Column(): - gr.Markdown("# Search and Export Items") - gr.Markdown("Search for items and export them as markdown files") - gr.Markdown("You can also export items by keyword") - search_query = gr.Textbox(label="Search Query") - search_type = gr.Radio(["Title", "URL", "Keyword", "Content"], label="Search By") - search_button = gr.Button("Search") - - with gr.Column(): - prev_button = gr.Button("Previous Page") - next_button = gr.Button("Next Page") - - current_page = gr.State(1) - total_pages = gr.State(1) - - search_results = gr.CheckboxGroup(label="Search Results", choices=[]) - export_selected_button = gr.Button("Export Selected Items") - - keyword_input = gr.Textbox(label="Enter keyword for export") - export_by_keyword_button = gr.Button("Export items by keyword") - - export_output = gr.File(label="Download Exported File") - error_output = gr.Textbox(label="Status/Error Messages", interactive=False) - - def search_and_update(query, search_type, page): - results, message, current, total = display_search_results_export_tab(query, search_type, page) - logger.debug(f"search_and_update results: {results}") - return results, message, current, total, gr.update(choices=results) - - search_button.click( - fn=search_and_update, - inputs=[search_query, search_type, current_page], - outputs=[search_results, error_output, current_page, total_pages, search_results], - show_progress="full" - ) - - - def update_page(current, total, direction): - new_page = max(1, min(total, current + direction)) - return new_page - - prev_button.click( - fn=update_page, - inputs=[current_page, total_pages, gr.State(-1)], - outputs=[current_page] - ).then( - fn=search_and_update, - inputs=[search_query, search_type, current_page], - outputs=[search_results, error_output, current_page, total_pages], - show_progress=True - ) - - next_button.click( - fn=update_page, - inputs=[current_page, total_pages, gr.State(1)], - outputs=[current_page] - ).then( - fn=search_and_update, - inputs=[search_query, search_type, current_page], - outputs=[search_results, error_output, current_page, total_pages], - show_progress=True - ) - - def handle_export_selected(selected_items): - logger.debug(f"Exporting selected items: {selected_items}") - return export_selected_items(selected_items) - - export_selected_button.click( - fn=handle_export_selected, - inputs=[search_results], - outputs=[export_output, error_output], - show_progress="full" - ) - - export_by_keyword_button.click( - fn=export_items_by_keyword, - inputs=[keyword_input], - outputs=[export_output, error_output], - show_progress="full" - ) - - def handle_item_selection(selected_items): - logger.debug(f"Selected items: {selected_items}") - if not selected_items: - return None, "No item selected" - - try: - # Assuming selected_items is a list of dictionaries - selected_item = selected_items[0] - logger.debug(f"First selected item: {selected_item}") - - # Check if 'value' is a string (JSON) or already a dictionary - if isinstance(selected_item['value'], str): - item_data = json.loads(selected_item['value']) - else: - item_data = selected_item['value'] - - logger.debug(f"Item data: {item_data}") - - item_id = item_data['id'] - return 
export_item_as_markdown(item_id) - except Exception as e: - error_message = f"Error processing selected item: {str(e)}" - logger.error(error_message) - return None, error_message - - search_results.select( - fn=handle_item_selection, - inputs=[search_results], - outputs=[export_output, error_output], - show_progress="full" - ) +################################################################ +# +# Functions for RAG Chat DB Export functionality + + +def export_rag_conversations_as_json( + selected_conversations: Optional[List[Dict[str, Any]]] = None +) -> Tuple[Optional[str], str]: + """ + Export conversations to a JSON file. + + Args: + selected_conversations: Optional list of conversation dictionaries + + Returns: + Tuple of (filename or None, status message) + """ + try: + if selected_conversations: + # Extract conversation IDs from selected items + conversation_ids = [] + for item in selected_conversations: + if isinstance(item, str): + item_data = json.loads(item) + elif isinstance(item, dict) and 'value' in item: + item_data = item['value'] if isinstance(item['value'], dict) else json.loads(item['value']) + else: + item_data = item + conversation_ids.append(item_data['conversation_id']) + + conversations = fetch_conversations_by_ids(conversation_ids) + else: + conversations = fetch_all_conversations() + + export_data = [] + for conversation_id, title, messages in conversations: + # Get keywords for the conversation + keywords = get_keywords_for_conversation(conversation_id) + + conversation_data = { + "conversation_id": conversation_id, + "title": title, + "keywords": keywords, + "messages": [ + {"role": role, "content": content} + for role, content in messages + ] + } + export_data.append(conversation_data) + + filename = "rag_conversations_export.json" + with open(filename, "w", encoding='utf-8') as f: + json.dump(export_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Successfully exported {len(export_data)} conversations to {filename}") + return filename, f"Successfully exported {len(export_data)} conversations to {filename}" + except Exception as e: + error_message = f"Error exporting conversations: {str(e)}" + logger.error(error_message) + return None, error_message + + +def export_rag_notes_as_json( + selected_notes: Optional[List[Dict[str, Any]]] = None +) -> Tuple[Optional[str], str]: + """ + Export notes to a JSON file. 
+ + Args: + selected_notes: Optional list of note dictionaries + + Returns: + Tuple of (filename or None, status message) + """ + try: + if selected_notes: + # Extract note IDs from selected items + note_ids = [] + for item in selected_notes: + if isinstance(item, str): + item_data = json.loads(item) + elif isinstance(item, dict) and 'value' in item: + item_data = item['value'] if isinstance(item['value'], dict) else json.loads(item['value']) + else: + item_data = item + note_ids.append(item_data['id']) + + notes = fetch_notes_by_ids(note_ids) + else: + notes = fetch_all_notes() + + export_data = [] + for note_id, title, content in notes: + # Get keywords for the note + keywords = get_keywords_for_note(note_id) + + note_data = { + "note_id": note_id, + "title": title, + "content": content, + "keywords": keywords + } + export_data.append(note_data) + + filename = "rag_notes_export.json" + with open(filename, "w", encoding='utf-8') as f: + json.dump(export_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Successfully exported {len(export_data)} notes to {filename}") + return filename, f"Successfully exported {len(export_data)} notes to {filename}" + except Exception as e: + error_message = f"Error exporting notes: {str(e)}" + logger.error(error_message) + return None, error_message + + +def display_rag_conversations(search_query: str = "", page: int = 1, items_per_page: int = 10): + """Display conversations for selection in the export tab.""" + try: + conversations = fetch_all_conversations() + + if search_query: + # Simple search implementation - can be enhanced based on needs + conversations = [ + conv for conv in conversations + if search_query.lower() in conv[1].lower() # Search in title + ] + + # Implement pagination + start_idx = (page - 1) * items_per_page + end_idx = start_idx + items_per_page + paginated_conversations = conversations[start_idx:end_idx] + total_pages = (len(conversations) + items_per_page - 1) // items_per_page + + # Format for checkbox group + checkbox_data = [ + { + "name": f"Title: {title}\nMessages: {len(messages)}", + "value": {"conversation_id": conv_id, "title": title} + } + for conv_id, title, messages in paginated_conversations + ] + + return ( + checkbox_data, + f"Found {len(conversations)} conversations (showing page {page} of {total_pages})", + page, + total_pages + ) + except Exception as e: + error_message = f"Error displaying conversations: {str(e)}" + logger.error(error_message) + return [], error_message, 1, 1 + + +def display_rag_notes(search_query: str = "", page: int = 1, items_per_page: int = 10): + """Display notes for selection in the export tab.""" + try: + notes = fetch_all_notes() + + if search_query: + # Simple search implementation - can be enhanced based on needs + notes = [ + note for note in notes + if search_query.lower() in note[1].lower() # Search in title + or search_query.lower() in note[2].lower() # Search in content + ] + + # Implement pagination + start_idx = (page - 1) * items_per_page + end_idx = start_idx + items_per_page + paginated_notes = notes[start_idx:end_idx] + total_pages = (len(notes) + items_per_page - 1) // items_per_page + + # Format for checkbox group + checkbox_data = [ + { + "name": f"Title: {title}\nContent preview: {content[:100]}...", + "value": {"id": note_id, "title": title} + } + for note_id, title, content in paginated_notes + ] + + return ( + checkbox_data, + f"Found {len(notes)} notes (showing page {page} of {total_pages})", + page, + total_pages + ) + except Exception as e: + error_message = 
f"Error displaying notes: {str(e)}" + logger.error(error_message) + return [], error_message, 1, 1 + + +def create_rag_export_tab(): + """Create the RAG QA Chat export tab interface.""" + with gr.Tab("RAG QA Chat Export"): + with gr.Tabs(): + # Conversations Export Tab + with gr.Tab("Export Conversations"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Conversations") + conversation_search = gr.Textbox(label="Search Conversations") + conversation_search_button = gr.Button("Search") + + with gr.Column(): + conversation_prev_button = gr.Button("Previous Page") + conversation_next_button = gr.Button("Next Page") + + conversation_current_page = gr.State(1) + conversation_total_pages = gr.State(1) + + conversation_results = gr.CheckboxGroup(label="Select Conversations to Export") + export_selected_conversations_button = gr.Button("Export Selected Conversations") + export_all_conversations_button = gr.Button("Export All Conversations") + + conversation_export_output = gr.File(label="Download Exported Conversations") + conversation_status = gr.Textbox(label="Status", interactive=False) + + # Notes Export Tab + with gr.Tab("Export Notes"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Notes") + notes_search = gr.Textbox(label="Search Notes") + notes_search_button = gr.Button("Search") + + with gr.Column(): + notes_prev_button = gr.Button("Previous Page") + notes_next_button = gr.Button("Next Page") + + notes_current_page = gr.State(1) + notes_total_pages = gr.State(1) + + notes_results = gr.CheckboxGroup(label="Select Notes to Export") + export_selected_notes_button = gr.Button("Export Selected Notes") + export_all_notes_button = gr.Button("Export All Notes") + + notes_export_output = gr.File(label="Download Exported Notes") + notes_status = gr.Textbox(label="Status", interactive=False) + + # Event handlers for conversations + def search_conversations(query, page): + return display_rag_conversations(query, page) + + conversation_search_button.click( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + def update_conversation_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + conversation_prev_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(-1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + conversation_next_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + export_selected_conversations_button.click( + fn=export_rag_conversations_as_json, + inputs=[conversation_results], + outputs=[conversation_export_output, conversation_status] + ) + + export_all_conversations_button.click( + fn=lambda: export_rag_conversations_as_json(), + outputs=[conversation_export_output, conversation_status] + ) + + # Event handlers for notes + def search_notes(query, page): + return 
display_rag_notes(query, page) + + notes_search_button.click( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + def update_notes_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + notes_prev_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(-1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + notes_next_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + export_selected_notes_button.click( + fn=export_rag_notes_as_json, + inputs=[notes_results], + outputs=[notes_export_output, notes_status] + ) + + export_all_notes_button.click( + fn=lambda: export_rag_notes_as_json(), + outputs=[notes_export_output, notes_status] + ) + +# +# End of RAG Chat DB Export functionality +##################################################### + +def create_export_tabs(): + """Create the unified export interface with all export tabs.""" + with gr.Tabs(): + # Media DB Export Tab + with gr.Tab("Media DB Export"): + with gr.Row(): + with gr.Column(): + gr.Markdown("# Search and Export Items") + gr.Markdown("Search for items and export them as markdown files") + gr.Markdown("You can also export items by keyword") + search_query = gr.Textbox(label="Search Query") + search_type = gr.Radio(["Title", "URL", "Keyword", "Content"], label="Search By") + search_button = gr.Button("Search") + + with gr.Column(): + prev_button = gr.Button("Previous Page") + next_button = gr.Button("Next Page") + + current_page = gr.State(1) + total_pages = gr.State(1) + + search_results = gr.CheckboxGroup(label="Search Results", choices=[]) + export_selected_button = gr.Button("Export Selected Items") + + keyword_input = gr.Textbox(label="Enter keyword for export") + export_by_keyword_button = gr.Button("Export items by keyword") + + export_output = gr.File(label="Download Exported File") + error_output = gr.Textbox(label="Status/Error Messages", interactive=False) + + # Conversations Export Tab + with gr.Tab("RAG Conversations Export"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Conversations") + conversation_search = gr.Textbox(label="Search Conversations") + conversation_search_button = gr.Button("Search") + + with gr.Column(): + conversation_prev_button = gr.Button("Previous Page") + conversation_next_button = gr.Button("Next Page") + + conversation_current_page = gr.State(1) + conversation_total_pages = gr.State(1) + + conversation_results = gr.CheckboxGroup(label="Select Conversations to Export") + export_selected_conversations_button = gr.Button("Export Selected Conversations") + export_all_conversations_button = gr.Button("Export All Conversations") + + conversation_export_output = gr.File(label="Download Exported Conversations") + conversation_status = gr.Textbox(label="Status", interactive=False) + + # Notes Export Tab + with gr.Tab("RAG Notes Export"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Notes") + notes_search = gr.Textbox(label="Search Notes") + notes_search_button = gr.Button("Search") 
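+                        # Search matches notes by a simple case-insensitive substring
+                        # check against title or content (see display_rag_notes).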
+ + with gr.Column(): + notes_prev_button = gr.Button("Previous Page") + notes_next_button = gr.Button("Next Page") + + notes_current_page = gr.State(1) + notes_total_pages = gr.State(1) + + notes_results = gr.CheckboxGroup(label="Select Notes to Export") + export_selected_notes_button = gr.Button("Export Selected Notes") + export_all_notes_button = gr.Button("Export All Notes") + + notes_export_output = gr.File(label="Download Exported Notes") + notes_status = gr.Textbox(label="Status", interactive=False) + + # Event handlers for media DB + def search_and_update(query, search_type, page): + results, message, current, total = display_search_results_export_tab(query, search_type, page) + logger.debug(f"search_and_update results: {results}") + return results, message, current, total, gr.update(choices=results) + + def update_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + def handle_export_selected(selected_items): + logger.debug(f"Exporting selected items: {selected_items}") + return export_selected_items(selected_items) + + def handle_item_selection(selected_items): + logger.debug(f"Selected items: {selected_items}") + if not selected_items: + return None, "No item selected" + + try: + selected_item = selected_items[0] + logger.debug(f"First selected item: {selected_item}") + + if isinstance(selected_item['value'], str): + item_data = json.loads(selected_item['value']) + else: + item_data = selected_item['value'] + + logger.debug(f"Item data: {item_data}") + item_id = item_data['id'] + return export_item_as_markdown(item_id) + except Exception as e: + error_message = f"Error processing selected item: {str(e)}" + logger.error(error_message) + return None, error_message + + search_button.click( + fn=search_and_update, + inputs=[search_query, search_type, current_page], + outputs=[search_results, error_output, current_page, total_pages, search_results], + show_progress="full" + ) + + prev_button.click( + fn=update_page, + inputs=[current_page, total_pages, gr.State(-1)], + outputs=[current_page] + ).then( + fn=search_and_update, + inputs=[search_query, search_type, current_page], + outputs=[search_results, error_output, current_page, total_pages], + show_progress=True + ) + + next_button.click( + fn=update_page, + inputs=[current_page, total_pages, gr.State(1)], + outputs=[current_page] + ).then( + fn=search_and_update, + inputs=[search_query, search_type, current_page], + outputs=[search_results, error_output, current_page, total_pages], + show_progress=True + ) + + export_selected_button.click( + fn=handle_export_selected, + inputs=[search_results], + outputs=[export_output, error_output], + show_progress="full" + ) + + export_by_keyword_button.click( + fn=export_items_by_keyword, + inputs=[keyword_input], + outputs=[export_output, error_output], + show_progress="full" + ) + + search_results.select( + fn=handle_item_selection, + inputs=[search_results], + outputs=[export_output, error_output], + show_progress="full" + ) + + # Event handlers for conversations + def search_conversations(query, page): + return display_rag_conversations(query, page) + + def update_conversation_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + conversation_search_button.click( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + 
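+        # Pagination chains two steps: .click() clamps and stores the new page number
+        # in State, then .then() re-runs the search so the results list, status text,
+        # and page counters update together.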
conversation_prev_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(-1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + conversation_next_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + export_selected_conversations_button.click( + fn=export_rag_conversations_as_json, + inputs=[conversation_results], + outputs=[conversation_export_output, conversation_status] + ) + + export_all_conversations_button.click( + fn=lambda: export_rag_conversations_as_json(), + outputs=[conversation_export_output, conversation_status] + ) + + # Event handlers for notes + def search_notes(query, page): + return display_rag_notes(query, page) + + def update_notes_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + notes_search_button.click( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + notes_prev_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(-1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + notes_next_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + export_selected_notes_button.click( + fn=export_rag_notes_as_json, + inputs=[notes_results], + outputs=[notes_export_output, notes_status] + ) + + export_all_notes_button.click( + fn=lambda: export_rag_notes_as_json(), + outputs=[notes_export_output, notes_status] + ) + + with gr.TabItem("Export Prompts", visible=True): + gr.Markdown("# Export Prompts Database Content") + + with gr.Row(): + with gr.Column(): + export_type = gr.Radio( + choices=["All Prompts", "Prompts by Keyword"], + label="Export Type", + value="All Prompts" + ) + + # Keyword selection for filtered export + with gr.Column(visible=False) as keyword_col: + keyword_input = gr.Textbox( + label="Enter Keywords (comma-separated)", + placeholder="Enter keywords to filter prompts..." 
+ ) + + # Export format selection + export_format = gr.Radio( + choices=["CSV", "Markdown (ZIP)"], + label="Export Format", + value="CSV" + ) + + # Export options + include_options = gr.CheckboxGroup( + choices=[ + "Include System Prompts", + "Include User Prompts", + "Include Details", + "Include Author", + "Include Keywords" + ], + label="Export Options", + value=["Include Keywords", "Include Author"] + ) + + # Markdown-specific options (only visible when Markdown is selected) + with gr.Column(visible=False) as markdown_options_col: + markdown_template = gr.Radio( + choices=[ + "Basic Template", + "Detailed Template", + "Custom Template" + ], + label="Markdown Template", + value="Basic Template" + ) + custom_template = gr.Textbox( + label="Custom Template", + placeholder="Use {title}, {author}, {details}, {system}, {user}, {keywords} as placeholders", + visible=False + ) + + export_button = gr.Button("Export Prompts") + + with gr.Column(): + export_status = gr.Textbox(label="Export Status", interactive=False) + export_file = gr.File(label="Download Export") + + def update_ui_visibility(export_type, format_choice, template_choice): + """Update UI elements visibility based on selections""" + show_keywords = export_type == "Prompts by Keyword" + show_markdown_options = format_choice == "Markdown (ZIP)" + show_custom_template = template_choice == "Custom Template" and show_markdown_options + + return [ + gr.update(visible=show_keywords), # keyword_col + gr.update(visible=show_markdown_options), # markdown_options_col + gr.update(visible=show_custom_template) # custom_template + ] + + def handle_export(export_type, keywords, export_format, options, markdown_template, custom_template): + """Handle the export process based on selected options""" + try: + # Parse options + include_system = "Include System Prompts" in options + include_user = "Include User Prompts" in options + include_details = "Include Details" in options + include_author = "Include Author" in options + include_keywords = "Include Keywords" in options + + # Handle keyword filtering + keyword_list = None + if export_type == "Prompts by Keyword" and keywords: + keyword_list = [k.strip() for k in keywords.split(",") if k.strip()] + + # Get the appropriate template + template = None + if export_format == "Markdown (ZIP)": + if markdown_template == "Custom Template": + template = custom_template + else: + template = markdown_template + + # Perform export + from App_Function_Libraries.DB.Prompts_DB import export_prompts + status, file_path = export_prompts( + export_format=export_format.split()[0].lower(), # 'csv' or 'markdown' + filter_keywords=keyword_list, + include_system=include_system, + include_user=include_user, + include_details=include_details, + include_author=include_author, + include_keywords=include_keywords, + markdown_template=template + ) + + return status, file_path + + except Exception as e: + error_msg = f"Export failed: {str(e)}" + logging.error(error_msg) + return error_msg, None + + # Event handlers + export_type.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_format.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + markdown_template.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, 
markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_button.click( + fn=handle_export, + inputs=[ + export_type, + keyword_input, + export_format, + include_options, + markdown_template, + custom_template + ], + outputs=[export_status, export_file] + ) + +# +# End of Export_Functionality.py +###################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Gradio_Shared.py b/App_Function_Libraries/Gradio_UI/Gradio_Shared.py index cfb2ea9c4..fd39a02cd 100644 --- a/App_Function_Libraries/Gradio_UI/Gradio_Shared.py +++ b/App_Function_Libraries/Gradio_UI/Gradio_Shared.py @@ -216,11 +216,6 @@ def format_content(content): return formatted_content -def update_prompt_dropdown(): - prompt_names = list_prompts() - return gr.update(choices=prompt_names) - - def display_prompt_details(selected_prompt): if selected_prompt: prompts = update_user_prompt(selected_prompt) diff --git a/App_Function_Libraries/Gradio_UI/Import_Functionality.py b/App_Function_Libraries/Gradio_UI/Import_Functionality.py index b73d974c4..37934d539 100644 --- a/App_Function_Libraries/Gradio_UI/Import_Functionality.py +++ b/App_Function_Libraries/Gradio_UI/Import_Functionality.py @@ -2,24 +2,31 @@ # Functionality to import content into the DB # # Imports +from datetime import datetime from time import sleep import logging import re import shutil import tempfile import os +from pathlib import Path +import sqlite3 import traceback +from typing import Optional, List, Dict, Tuple +import uuid import zipfile # # External Imports import gradio as gr +from chardet import detect + # # Local Imports -from App_Function_Libraries.DB.DB_Manager import insert_prompt_to_db, load_preset_prompts, import_obsidian_note_to_db, \ - add_media_to_database +from App_Function_Libraries.DB.DB_Manager import insert_prompt_to_db, import_obsidian_note_to_db, \ + add_media_to_database, list_prompts from App_Function_Libraries.Prompt_Handling import import_prompt_from_file, import_prompts_from_zip# from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_summarization - +# ################################################################################################################### # # Functions: @@ -203,15 +210,6 @@ def save_prompt_to_db(title, author, system, user, keywords): outputs=save_output ) - def update_prompt_dropdown(): - return gr.update(choices=load_preset_prompts()) - - save_button.click( - fn=update_prompt_dropdown, - inputs=[], - outputs=[gr.Dropdown(label="Select Preset Prompt")] - ) - def create_import_item_tab(): with gr.TabItem("Import Markdown/Text Files", visible=True): gr.Markdown("# Import a markdown file or text file into the database") @@ -250,11 +248,18 @@ def create_import_multiple_prompts_tab(): gr.Markdown("# Import multiple prompts into the database") gr.Markdown("Upload a zip file containing multiple prompt files (txt or md)") + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): zip_file = gr.File(label="Upload zip file for import", file_types=["zip"]) import_button = gr.Button("Import Prompts") prompts_dropdown = gr.Dropdown(label="Select Prompt to Edit", choices=[]) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) 
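+            # Note on the pagination contract: `list_prompts(page, per_page)` is
+            # assumed, from its call sites below, to return a 3-tuple, e.g.:
+            #
+            #   prompts, total_pages, current_page = list_prompts(page=1, per_page=10)
+            #   # prompts      -> ["Prompt A", "Prompt B", ...] (titles for the dropdown)
+            #   # total_pages  -> page count at 10 prompts per page
+            #   # current_page -> the page actually served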
title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content") author_input = gr.Textbox(label="Author", placeholder="Enter the author's name") system_input = gr.Textbox(label="System", placeholder="Enter the system message for the prompt", @@ -268,6 +273,10 @@ def create_import_multiple_prompts_tab(): save_output = gr.Textbox(label="Save Status") prompts_display = gr.Textbox(label="Identified Prompts") + # State to store imported prompts + zip_import_state = gr.State([]) + + # Function to handle zip import def handle_zip_import(zip_file): result = import_prompts_from_zip(zip_file) if isinstance(result, list): @@ -278,6 +287,13 @@ def handle_zip_import(zip_file): else: return gr.update(value=result), [], gr.update(value=""), [] + import_button.click( + fn=handle_zip_import, + inputs=[zip_file], + outputs=[import_output, prompts_dropdown, prompts_display, zip_import_state] + ) + + # Function to handle prompt selection from imported prompts def handle_prompt_selection(selected_title, prompts): selected_prompt = next((prompt for prompt in prompts if prompt['title'] == selected_title), None) if selected_prompt: @@ -305,23 +321,68 @@ def handle_prompt_selection(selected_title, prompts): outputs=[title_input, author_input, system_input, user_input, keywords_input] ) + # Function to save prompt to the database def save_prompt_to_db(title, author, system, user, keywords): keyword_list = [k.strip() for k in keywords.split(',') if k.strip()] - return insert_prompt_to_db(title, author, system, user, keyword_list) + result = insert_prompt_to_db(title, author, system, user, keyword_list) + return result save_button.click( fn=save_prompt_to_db, inputs=[title_input, author_input, system_input, user_input, keywords_input], - outputs=save_output + outputs=[save_output] + ) + + # Adding pagination controls to navigate prompts in the database + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[prompts_dropdown, page_display, current_page_state] + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[prompts_dropdown, page_display, current_page_state] ) + # Function to update prompts dropdown after saving to the database def update_prompt_dropdown(): - return gr.update(choices=load_preset_prompts()) + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(visible=True), + gr.update(value=page_display_text, visible=True), + current_page, + total_pages + ) + # Update the dropdown after saving save_button.click( fn=update_prompt_dropdown, inputs=[], - outputs=[gr.Dropdown(label="Select Preset Prompt")] + outputs=[prompts_dropdown, prev_page_button, page_display, 
current_page_state, total_pages_state] ) @@ -385,4 +446,392 @@ def import_obsidian_vault(vault_path, progress=gr.Progress()): except Exception as e: error_msg = f"Error scanning vault: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) - return 0, 0, [error_msg] \ No newline at end of file + return 0, 0, [error_msg] + + +class RAGQABatchImporter: + def __init__(self, db_path: str): + self.db_path = Path(db_path) + self.setup_logging() + self.file_processor = FileProcessor() + self.zip_validator = ZipValidator() + + def setup_logging(self): + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('rag_qa_import.log'), + logging.StreamHandler() + ] + ) + + def process_markdown_content(self, content: str) -> List[Dict[str, str]]: + """Process markdown content into a conversation format.""" + messages = [] + sections = content.split('\n\n') + + for section in sections: + if section.strip(): + messages.append({ + 'role': 'user', + 'content': section.strip() + }) + + return messages + + def process_keywords(self, db: sqlite3.Connection, conversation_id: str, keywords: str): + """Process and link keywords to a conversation.""" + if not keywords: + return + + keyword_list = [k.strip() for k in keywords.split(',')] + for keyword in keyword_list: + # Insert keyword if it doesn't exist + db.execute(""" + INSERT OR IGNORE INTO rag_qa_keywords (keyword) + VALUES (?) + """, (keyword,)) + + # Get keyword ID + keyword_id = db.execute(""" + SELECT id FROM rag_qa_keywords WHERE keyword = ? + """, (keyword,)).fetchone()[0] + + # Link keyword to conversation + db.execute(""" + INSERT INTO rag_qa_conversation_keywords + (conversation_id, keyword_id) + VALUES (?, ?) + """, (conversation_id, keyword_id)) + + def import_single_file( + self, + db: sqlite3.Connection, + content: str, + filename: str, + keywords: str, + custom_prompt: Optional[str] = None, + rating: Optional[int] = None + ) -> str: + """Import a single file's content into the database""" + conversation_id = str(uuid.uuid4()) + current_time = datetime.now().isoformat() + + # Process filename into title + title = FileProcessor.process_filename_to_title(filename) + if title.lower().endswith(('.md', '.txt')): + title = title[:-3] if title.lower().endswith('.md') else title[:-4] + + # Insert conversation metadata + db.execute(""" + INSERT INTO conversation_metadata + (conversation_id, created_at, last_updated, title, rating) + VALUES (?, ?, ?, ?, ?) + """, (conversation_id, current_time, current_time, title, rating)) + + # Process content and insert messages + messages = self.process_markdown_content(content) + for msg in messages: + db.execute(""" + INSERT INTO rag_qa_chats + (conversation_id, timestamp, role, content) + VALUES (?, ?, ?, ?) 
+ """, (conversation_id, current_time, msg['role'], msg['content'])) + + # Process keywords + self.process_keywords(db, conversation_id, keywords) + + return conversation_id + + def extract_zip(self, zip_path: str) -> List[Tuple[str, str]]: + """Extract and validate files from zip""" + is_valid, error_msg, valid_files = self.zip_validator.validate_zip_file(zip_path) + if not is_valid: + raise ValueError(error_msg) + + files = [] + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + for filename in valid_files: + with zip_ref.open(filename) as f: + content = f.read() + # Try to decode with detected encoding + try: + detected_encoding = detect(content)['encoding'] or 'utf-8' + content = content.decode(detected_encoding) + except UnicodeDecodeError: + content = content.decode('utf-8', errors='replace') + + filename = os.path.basename(filename) + files.append((filename, content)) + return files + + def import_files( + self, + files: List[str], + keywords: str = "", + custom_prompt: Optional[str] = None, + rating: Optional[int] = None, + progress=gr.Progress() + ) -> Tuple[bool, str]: + """Import multiple files or zip files into the RAG QA database.""" + try: + imported_files = [] + + with sqlite3.connect(self.db_path) as db: + # Process each file + for file_path in progress.tqdm(files, desc="Processing files"): + filename = os.path.basename(file_path) + + # Handle zip files + if filename.lower().endswith('.zip'): + zip_files = self.extract_zip(file_path) + for zip_filename, content in progress.tqdm(zip_files, desc=f"Processing files from {filename}"): + conv_id = self.import_single_file( + db=db, + content=content, + filename=zip_filename, + keywords=keywords, + custom_prompt=custom_prompt, + rating=rating + ) + imported_files.append(zip_filename) + + # Handle individual markdown/text files + elif filename.lower().endswith(('.md', '.txt')): + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + conv_id = self.import_single_file( + db=db, + content=content, + filename=filename, + keywords=keywords, + custom_prompt=custom_prompt, + rating=rating + ) + imported_files.append(filename) + + db.commit() + + return True, f"Successfully imported {len(imported_files)} files:\n" + "\n".join(imported_files) + + except Exception as e: + logging.error(f"Import failed: {str(e)}") + return False, f"Import failed: {str(e)}" + + +class FileProcessor: + """Handles file reading and name processing""" + + VALID_EXTENSIONS = {'.md', '.txt', '.zip'} + ENCODINGS_TO_TRY = [ + 'utf-8', + 'utf-16', + 'windows-1252', + 'iso-8859-1', + 'ascii' + ] + + @staticmethod + def detect_encoding(file_path: str) -> str: + """Detect the file encoding using chardet""" + with open(file_path, 'rb') as file: + raw_data = file.read() + result = detect(raw_data) + return result['encoding'] or 'utf-8' + + @staticmethod + def read_file_content(file_path: str) -> str: + """Read file content with automatic encoding detection""" + detected_encoding = FileProcessor.detect_encoding(file_path) + + # Try detected encoding first + try: + with open(file_path, 'r', encoding=detected_encoding) as f: + return f.read() + except UnicodeDecodeError: + # If detected encoding fails, try others + for encoding in FileProcessor.ENCODINGS_TO_TRY: + try: + with open(file_path, 'r', encoding=encoding) as f: + return f.read() + except UnicodeDecodeError: + continue + + # If all encodings fail, use utf-8 with error handling + with open(file_path, 'r', encoding='utf-8', errors='replace') as f: + return f.read() + + @staticmethod + def 
process_filename_to_title(filename: str) -> str: + """Convert filename to a readable title""" + # Remove extension + name = os.path.splitext(filename)[0] + + # Look for date patterns + date_pattern = r'(\d{4}[-_]?\d{2}[-_]?\d{2})' + date_match = re.search(date_pattern, name) + date_str = "" + if date_match: + try: + date = datetime.strptime(date_match.group(1).replace('_', '-'), '%Y-%m-%d') + date_str = date.strftime("%b %d, %Y") + name = name.replace(date_match.group(1), '').strip('-_') + except ValueError: + pass + + # Replace separators with spaces + name = re.sub(r'[-_]+', ' ', name) + + # Remove redundant spaces + name = re.sub(r'\s+', ' ', name).strip() + + # Capitalize words, excluding certain words + exclude_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with'} + words = name.split() + capitalized = [] + for i, word in enumerate(words): + if i == 0 or word not in exclude_words: + capitalized.append(word.capitalize()) + else: + capitalized.append(word.lower()) + name = ' '.join(capitalized) + + # Add date if found + if date_str: + name = f"{name} - {date_str}" + + return name + + +class ZipValidator: + """Validates zip file contents and structure""" + + MAX_ZIP_SIZE = 100 * 1024 * 1024 # 100MB + MAX_FILES = 100 + VALID_EXTENSIONS = {'.md', '.txt'} + + @staticmethod + def validate_zip_file(zip_path: str) -> Tuple[bool, str, List[str]]: + """ + Validate zip file and its contents + Returns: (is_valid, error_message, valid_files) + """ + try: + # Check zip file size + if os.path.getsize(zip_path) > ZipValidator.MAX_ZIP_SIZE: + return False, "Zip file too large (max 100MB)", [] + + valid_files = [] + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + # Check number of files + if len(zip_ref.filelist) > ZipValidator.MAX_FILES: + return False, f"Too many files in zip (max {ZipValidator.MAX_FILES})", [] + + # Check for directory traversal attempts + for file_info in zip_ref.filelist: + if '..' 
in file_info.filename or file_info.filename.startswith('/'): + return False, "Invalid file paths detected", [] + + # Validate each file + total_size = 0 + for file_info in zip_ref.filelist: + # Skip directories + if file_info.filename.endswith('/'): + continue + + # Check file size + if file_info.file_size > ZipValidator.MAX_ZIP_SIZE: + return False, f"File {file_info.filename} too large", [] + + total_size += file_info.file_size + if total_size > ZipValidator.MAX_ZIP_SIZE: + return False, "Total uncompressed size too large", [] + + # Check file extension + ext = os.path.splitext(file_info.filename)[1].lower() + if ext in ZipValidator.VALID_EXTENSIONS: + valid_files.append(file_info.filename) + + if not valid_files: + return False, "No valid markdown or text files found in zip", [] + + return True, "", valid_files + + except zipfile.BadZipFile: + return False, "Invalid or corrupted zip file", [] + except Exception as e: + return False, f"Error processing zip file: {str(e)}", [] + + +def create_conversation_import_tab() -> gr.Tab: + """Create the import tab for the Gradio interface""" + with gr.Tab("Import RAG Chats") as tab: + gr.Markdown("# Import RAG Chats into the Database") + gr.Markdown(""" + Import your RAG Chat markdown/text files individually or as a zip archive + + Supported file types: + - Markdown (.md) + - Text (.txt) + - Zip archives containing .md or .txt files + + Maximum zip file size: 100MB + Maximum files per zip: 100 + """) + with gr.Row(): + with gr.Column(): + import_files = gr.File( + label="Upload Files", + file_types=["txt", "md", "zip"], + file_count="multiple" + ) + + keywords_input = gr.Textbox( + label="Keywords", + placeholder="Enter keywords to apply to all imported files (comma-separated)" + ) + + custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter a custom prompt for processing (optional)" + ) + + rating_input = gr.Slider( + minimum=1, + maximum=3, + step=1, + label="Rating (1-3)", + value=None + ) + + with gr.Column(): + import_button = gr.Button("Import Files") + import_output = gr.Textbox( + label="Import Status", + lines=10 + ) + + def handle_import(files, keywords, custom_prompt, rating): + importer = RAGQABatchImporter("rag_qa.db") # Update with your DB path + success, message = importer.import_files( + files=[f.name for f in files], + keywords=keywords, + custom_prompt=custom_prompt, + rating=rating + ) + return message + + import_button.click( + fn=handle_import, + inputs=[ + import_files, + keywords_input, + custom_prompt_input, + rating_input + ], + outputs=import_output + ) + + return tab diff --git a/App_Function_Libraries/Gradio_UI/Keywords.py b/App_Function_Libraries/Gradio_UI/Keywords.py index b2c7a213b..dd2592682 100644 --- a/App_Function_Libraries/Gradio_UI/Keywords.py +++ b/App_Function_Libraries/Gradio_UI/Keywords.py @@ -4,22 +4,29 @@ # The Keywords tab allows the user to add, delete, view, and export keywords from the database. 
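+#
+# The tab builders below are grouped by backing store: the Media DB, Character
+# Chat DB, RAG QA DB, and Prompts DB each expose their own view/add/delete/
+# export helpers, and each create_*_tab function wires one store's helpers to
+# its Gradio components.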
# # Imports: - # # External Imports import gradio as gr + +from App_Function_Libraries.DB.Character_Chat_DB import view_char_keywords, add_char_keywords, delete_char_keyword, \ + export_char_keywords_to_csv # # Internal Imports from App_Function_Libraries.DB.DB_Manager import add_keyword, delete_keyword, keywords_browser_interface, export_keywords_to_csv -# +from App_Function_Libraries.DB.Prompts_DB import view_prompt_keywords, delete_prompt_keyword, \ + export_prompt_keywords_to_csv +from App_Function_Libraries.DB.RAG_QA_Chat_DB import view_rag_keywords, get_all_collections, \ + get_keywords_for_collection, create_keyword_collection, add_keyword_to_collection, delete_rag_keyword, \ + export_rag_keywords_to_csv + + # ###################################################################################################################### # # Functions: - def create_export_keywords_tab(): - with gr.TabItem("Export Keywords", visible=True): + with gr.TabItem("Export MediaDB Keywords", visible=True): with gr.Row(): with gr.Column(): export_keywords_button = gr.Button("Export Keywords") @@ -33,8 +40,8 @@ def create_export_keywords_tab(): ) def create_view_keywords_tab(): - with gr.TabItem("View Keywords", visible=True): - gr.Markdown("# Browse Keywords") + with gr.TabItem("View MediaDB Keywords", visible=True): + gr.Markdown("# Browse MediaDB Keywords") with gr.Column(): browse_output = gr.Markdown() browse_button = gr.Button("View Existing Keywords") @@ -42,7 +49,7 @@ def create_view_keywords_tab(): def create_add_keyword_tab(): - with gr.TabItem("Add Keywords", visible=True): + with gr.TabItem("Add MediaDB Keywords", visible=True): with gr.Row(): with gr.Column(): gr.Markdown("# Add Keywords to the Database") @@ -54,7 +61,7 @@ def create_add_keyword_tab(): def create_delete_keyword_tab(): - with gr.Tab("Delete Keywords", visible=True): + with gr.Tab("Delete MediaDB Keywords", visible=True): with gr.Row(): with gr.Column(): gr.Markdown("# Delete Keywords from the Database") @@ -63,3 +70,289 @@ def create_delete_keyword_tab(): with gr.Row(): delete_output = gr.Textbox(label="Result") delete_button.click(fn=delete_keyword, inputs=delete_input, outputs=delete_output) + +# +# End of Media DB Keyword tabs +########################################################## + + +############################################################ +# +# Character DB Keyword functions + +def create_character_keywords_tab(): + """Creates the Character Keywords management tab""" + with gr.Tab("Character Keywords"): + gr.Markdown("# Character Keywords Management") + + with gr.Tabs(): + # View Character Keywords Tab + with gr.TabItem("View Keywords"): + with gr.Column(): + refresh_char_keywords = gr.Button("Refresh Character Keywords") + char_keywords_output = gr.Markdown() + view_char_keywords() + refresh_char_keywords.click( + fn=view_char_keywords, + outputs=char_keywords_output + ) + + # Add Character Keywords Tab + with gr.TabItem("Add Keywords"): + with gr.Column(): + char_name = gr.Textbox(label="Character Name") + new_keywords = gr.Textbox(label="New Keywords (comma-separated)") + add_char_keyword_btn = gr.Button("Add Keywords") + add_char_result = gr.Markdown() + + add_char_keyword_btn.click( + fn=add_char_keywords, + inputs=[char_name, new_keywords], + outputs=add_char_result + ) + + # Delete Character Keywords Tab (New) + with gr.TabItem("Delete Keywords"): + with gr.Column(): + delete_char_name = gr.Textbox(label="Character Name") + delete_char_keyword_input = gr.Textbox(label="Keyword to Delete") + 
delete_char_keyword_btn = gr.Button("Delete Keyword") + delete_char_result = gr.Markdown() + + delete_char_keyword_btn.click( + fn=delete_char_keyword, + inputs=[delete_char_name, delete_char_keyword_input], + outputs=delete_char_result + ) + + # Export Character Keywords Tab (New) + with gr.TabItem("Export Keywords"): + with gr.Column(): + export_char_keywords_btn = gr.Button("Export Character Keywords") + export_char_file = gr.File(label="Download Exported Keywords") + export_char_status = gr.Textbox(label="Export Status") + + export_char_keywords_btn.click( + fn=export_char_keywords_to_csv, + outputs=[export_char_status, export_char_file] + ) + +# +# End of Character Keywords tab +########################################################## + +############################################################ +# +# RAG QA Keywords functions + +def create_rag_qa_keywords_tab(): + """Creates the RAG QA Keywords management tab""" + with gr.Tab("RAG QA Keywords"): + gr.Markdown("# RAG QA Keywords Management") + + with gr.Tabs(): + # View RAG QA Keywords Tab + with gr.TabItem("View Keywords"): + with gr.Column(): + refresh_rag_keywords = gr.Button("Refresh RAG QA Keywords") + rag_keywords_output = gr.Markdown() + + view_rag_keywords() + + refresh_rag_keywords.click( + fn=view_rag_keywords, + outputs=rag_keywords_output + ) + + # Add RAG QA Keywords Tab + with gr.TabItem("Add Keywords"): + with gr.Column(): + new_rag_keywords = gr.Textbox(label="New Keywords (comma-separated)") + add_rag_keyword_btn = gr.Button("Add Keywords") + add_rag_result = gr.Markdown() + + add_rag_keyword_btn.click( + fn=add_keyword, + inputs=new_rag_keywords, + outputs=add_rag_result + ) + + # Delete RAG QA Keywords Tab (New) + with gr.TabItem("Delete Keywords"): + with gr.Column(): + delete_rag_keyword_input = gr.Textbox(label="Keyword to Delete") + delete_rag_keyword_btn = gr.Button("Delete Keyword") + delete_rag_result = gr.Markdown() + + delete_rag_keyword_btn.click( + fn=delete_rag_keyword, + inputs=delete_rag_keyword_input, + outputs=delete_rag_result + ) + + # Export RAG QA Keywords Tab (New) + with gr.TabItem("Export Keywords"): + with gr.Column(): + export_rag_keywords_btn = gr.Button("Export RAG QA Keywords") + export_rag_file = gr.File(label="Download Exported Keywords") + export_rag_status = gr.Textbox(label="Export Status") + + export_rag_keywords_btn.click( + fn=export_rag_keywords_to_csv, + outputs=[export_rag_status, export_rag_file] + ) + +# +# End of RAG QA Keywords tab +########################################################## + + +############################################################ +# +# Prompt Keywords functions + +def create_prompt_keywords_tab(): + """Creates the Prompt Keywords management tab""" + with gr.Tab("Prompt Keywords"): + gr.Markdown("# Prompt Keywords Management") + + with gr.Tabs(): + # View Keywords Tab + with gr.TabItem("View Keywords"): + with gr.Column(): + refresh_prompt_keywords = gr.Button("Refresh Prompt Keywords") + prompt_keywords_output = gr.Markdown() + + refresh_prompt_keywords.click( + fn=view_prompt_keywords, + outputs=prompt_keywords_output + ) + + # Add Keywords Tab (using existing prompt management functions) + with gr.TabItem("Add Keywords"): + gr.Markdown(""" + To add keywords to prompts, please use the Prompt Management interface. + Keywords can be added when creating or editing a prompt. 
+ """) + + # Delete Keywords Tab + with gr.TabItem("Delete Keywords"): + with gr.Column(): + delete_prompt_keyword_input = gr.Textbox(label="Keyword to Delete") + delete_prompt_keyword_btn = gr.Button("Delete Keyword") + delete_prompt_result = gr.Markdown() + + delete_prompt_keyword_btn.click( + fn=delete_prompt_keyword, + inputs=delete_prompt_keyword_input, + outputs=delete_prompt_result + ) + + # Export Keywords Tab + with gr.TabItem("Export Keywords"): + with gr.Column(): + export_prompt_keywords_btn = gr.Button("Export Prompt Keywords") + export_prompt_status = gr.Textbox(label="Export Status", interactive=False) + export_prompt_file = gr.File(label="Download Exported Keywords", interactive=False) + + def handle_export(): + status, file_path = export_prompt_keywords_to_csv() + if file_path: + return status, file_path + return status, None + + export_prompt_keywords_btn.click( + fn=handle_export, + outputs=[export_prompt_status, export_prompt_file] + ) +# +# End of Prompt Keywords tab +############################################################ + + +############################################################ +# +# Meta-Keywords functions + +def create_meta_keywords_tab(): + """Creates the Meta-Keywords management tab""" + with gr.Tab("Meta-Keywords"): + gr.Markdown("# Meta-Keywords Management") + + with gr.Tabs(): + # View Meta-Keywords Tab + with gr.TabItem("View Collections"): + with gr.Column(): + refresh_collections = gr.Button("Refresh Collections") + collections_output = gr.Markdown() + + def view_collections(): + try: + collections, _, _ = get_all_collections() + if collections: + result = "### Keyword Collections:\n" + for collection in collections: + keywords = get_keywords_for_collection(collection) + result += f"\n**{collection}**:\n" + result += "\n".join([f"- {k}" for k in keywords]) + result += "\n" + return result + return "No collections found." 
+ except Exception as e: + return f"Error retrieving collections: {str(e)}" + + refresh_collections.click( + fn=view_collections, + outputs=collections_output + ) + + # Create Collection Tab + with gr.TabItem("Create Collection"): + with gr.Column(): + collection_name = gr.Textbox(label="Collection Name") + create_collection_btn = gr.Button("Create Collection") + create_result = gr.Markdown() + + def create_collection(name: str): + try: + create_keyword_collection(name) + return f"Successfully created collection: {name}" + except Exception as e: + return f"Error creating collection: {str(e)}" + + create_collection_btn.click( + fn=create_collection, + inputs=collection_name, + outputs=create_result + ) + + # Add Keywords to Collection Tab + with gr.TabItem("Add to Collection"): + with gr.Column(): + collection_select = gr.Textbox(label="Collection Name") + keywords_to_add = gr.Textbox(label="Keywords to Add (comma-separated)") + add_to_collection_btn = gr.Button("Add Keywords to Collection") + add_to_collection_result = gr.Markdown() + + def add_keywords_to_collection(collection: str, keywords: str): + try: + keywords_list = [k.strip() for k in keywords.split(",") if k.strip()] + for keyword in keywords_list: + add_keyword_to_collection(collection, keyword) + return f"Successfully added {len(keywords_list)} keywords to collection {collection}" + except Exception as e: + return f"Error adding keywords to collection: {str(e)}" + + add_to_collection_btn.click( + fn=add_keywords_to_collection, + inputs=[collection_select, keywords_to_add], + outputs=add_to_collection_result + ) + +# +# End of Meta-Keywords tab +########################################################## + +# +# End of Keywords.py +###################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Media_edit.py b/App_Function_Libraries/Gradio_UI/Media_edit.py index 0912736b4..4cbfdd2ef 100644 --- a/App_Function_Libraries/Gradio_UI/Media_edit.py +++ b/App_Function_Libraries/Gradio_UI/Media_edit.py @@ -10,13 +10,13 @@ # # Local Imports from App_Function_Libraries.DB.DB_Manager import add_prompt, update_media_content, db, add_or_update_prompt, \ - load_prompt_details, fetch_keywords_for_media, update_keywords_for_media -from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown, update_prompt_dropdown + load_prompt_details, fetch_keywords_for_media, update_keywords_for_media, fetch_prompt_details, list_prompts +from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown from App_Function_Libraries.DB.SQLite_DB import fetch_item_details def create_media_edit_tab(): - with gr.TabItem("Edit Existing Items", visible=True): + with gr.TabItem("Edit Existing Items in the Media DB", visible=True): gr.Markdown("# Search and Edit Media Items") with gr.Row(): @@ -89,7 +89,7 @@ def update_media_with_keywords(selected_item, item_mapping, content, prompt, sum def create_media_edit_and_clone_tab(): - with gr.TabItem("Clone and Edit Existing Items", visible=True): + with gr.TabItem("Clone and Edit Existing Items in the Media DB", visible=True): gr.Markdown("# Search, Edit, and Clone Existing Items") with gr.Row(): @@ -199,6 +199,11 @@ def save_cloned_item(selected_item, item_mapping, content, prompt, summary, new_ def create_prompt_edit_tab(): + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + per_page = 10 # Number of prompts per page + with 
gr.TabItem("Add & Edit Prompts", visible=True): with gr.Row(): with gr.Column(): @@ -207,38 +212,145 @@ def create_prompt_edit_tab(): choices=[], interactive=True ) + next_page_button = gr.Button("Next Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) prompt_list_button = gr.Button("List Prompts") with gr.Column(): title_input = gr.Textbox(label="Title", placeholder="Enter the prompt title") - author_input = gr.Textbox(label="Author", placeholder="Enter the prompt's author", lines=3) + author_input = gr.Textbox(label="Author", placeholder="Enter the prompt's author", lines=1) description_input = gr.Textbox(label="Description", placeholder="Enter the prompt description", lines=3) system_prompt_input = gr.Textbox(label="System Prompt", placeholder="Enter the system prompt", lines=3) user_prompt_input = gr.Textbox(label="User Prompt", placeholder="Enter the user prompt", lines=3) add_prompt_button = gr.Button("Add/Update Prompt") add_prompt_output = gr.HTML() - # Event handlers + # Function to update the prompt dropdown with pagination + def update_prompt_dropdown(page=1): + prompts, total_pages, current_page = list_prompts(page=page, per_page=per_page) + page_display_text = f"Page {current_page} of {total_pages}" + prev_button_visible = current_page > 1 + next_button_visible = current_page < total_pages + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text, visible=True), + gr.update(visible=prev_button_visible), + gr.update(visible=next_button_visible), + current_page, + total_pages + ) + + # Event handler for listing prompts prompt_list_button.click( fn=update_prompt_dropdown, - outputs=prompt_dropdown + inputs=[], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + # Functions to handle pagination + def on_prev_page_click(current_page): + new_page = max(current_page - 1, 1) + return update_prompt_dropdown(page=new_page) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + return update_prompt_dropdown(page=new_page) + + # Event handlers for pagination buttons + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] ) + # Event handler for adding or updating a prompt add_prompt_button.click( fn=add_or_update_prompt, inputs=[title_input, author_input, description_input, system_prompt_input, user_prompt_input], - outputs=add_prompt_output + outputs=[add_prompt_output] + ).then( + fn=update_prompt_dropdown, + inputs=[], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] ) - # Load prompt details when selected + # Function to load prompt details when a prompt is selected + def load_prompt_details(selected_prompt): + details = fetch_prompt_details(selected_prompt) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + return ( + gr.update(value=title), + gr.update(value=author or ""), + gr.update(value=description or ""), + 
gr.update(value=system_prompt or ""), + gr.update(value=user_prompt or "") + ) + else: + return ( + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value="") + ) + + # Event handler for prompt selection change prompt_dropdown.change( fn=load_prompt_details, inputs=[prompt_dropdown], - outputs=[title_input, author_input, system_prompt_input, user_prompt_input] + outputs=[ + title_input, + author_input, + description_input, + system_prompt_input, + user_prompt_input + ] ) + def create_prompt_clone_tab(): + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + per_page = 10 # Number of prompts per page + with gr.TabItem("Clone and Edit Prompts", visible=True): with gr.Row(): with gr.Column(): @@ -248,6 +360,9 @@ def create_prompt_clone_tab(): choices=[], interactive=True ) + next_page_button = gr.Button("Next Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) prompt_list_button = gr.Button("List Prompts") with gr.Column(): @@ -260,19 +375,99 @@ def create_prompt_clone_tab(): save_cloned_prompt_button = gr.Button("Save Cloned Prompt", visible=False) add_prompt_output = gr.HTML() - # Event handlers + # Function to update the prompt dropdown with pagination + def update_prompt_dropdown(page=1): + prompts, total_pages, current_page = list_prompts(page=page, per_page=per_page) + page_display_text = f"Page {current_page} of {total_pages}" + prev_button_visible = current_page > 1 + next_button_visible = current_page < total_pages + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text, visible=True), + gr.update(visible=prev_button_visible), + gr.update(visible=next_button_visible), + current_page, + total_pages + ) + + # Event handler for listing prompts prompt_list_button.click( fn=update_prompt_dropdown, - outputs=prompt_dropdown + inputs=[], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + # Functions to handle pagination + def on_prev_page_click(current_page): + new_page = max(current_page - 1, 1) + return update_prompt_dropdown(page=new_page) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + return update_prompt_dropdown(page=new_page) + + # Event handlers for pagination buttons + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] ) # Load prompt details when selected + def load_prompt_details(selected_prompt): + if selected_prompt: + details = fetch_prompt_details(selected_prompt) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + return ( + gr.update(value=title), + gr.update(value=author or ""), + gr.update(value=description or ""), + gr.update(value=system_prompt or ""), + gr.update(value=user_prompt or "") + ) + return ( + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value="") + ) + prompt_dropdown.change( 
fn=load_prompt_details, inputs=[prompt_dropdown], outputs=[title_input, author_input, description_input, system_prompt_input, user_prompt_input] ) + # Prepare for cloning def prepare_for_cloning(selected_prompt): if selected_prompt: return gr.update(value=f"Copy of {selected_prompt}"), gr.update(visible=True) @@ -284,18 +479,21 @@ def prepare_for_cloning(selected_prompt): outputs=[title_input, save_cloned_prompt_button] ) - def save_cloned_prompt(title, description, system_prompt, user_prompt): + # Function to save cloned prompt + def save_cloned_prompt(title, author, description, system_prompt, user_prompt, current_page): try: - result = add_prompt(title, description, system_prompt, user_prompt) + result = add_prompt(title, author, description, system_prompt, user_prompt) if result == "Prompt added successfully.": - return result, gr.update(choices=update_prompt_dropdown()) + # After adding, refresh the prompt dropdown + prompt_dropdown_update = update_prompt_dropdown(page=current_page) + return (result, *prompt_dropdown_update) else: - return result, gr.update() + return (result, gr.update(), gr.update(), gr.update(), gr.update(), current_page, total_pages_state.value) except Exception as e: - return f"Error saving cloned prompt: {str(e)}", gr.update() + return (f"Error saving cloned prompt: {str(e)}", gr.update(), gr.update(), gr.update(), gr.update(), current_page, total_pages_state.value) save_cloned_prompt_button.click( fn=save_cloned_prompt, - inputs=[title_input, description_input, system_prompt_input, user_prompt_input], - outputs=[add_prompt_output, prompt_dropdown] - ) \ No newline at end of file + inputs=[title_input, author_input, description_input, system_prompt_input, user_prompt_input, current_page_state], + outputs=[add_prompt_output, prompt_dropdown, page_display, prev_page_button, next_page_button, current_page_state, total_pages_state] + ) diff --git a/App_Function_Libraries/Gradio_UI/Mind_Map_tab.py b/App_Function_Libraries/Gradio_UI/Mind_Map_tab.py new file mode 100644 index 000000000..5666b8fc1 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Mind_Map_tab.py @@ -0,0 +1,128 @@ +# Mind_Map_tab.py +# Description: File contains functions for generation of PlantUML mindmaps for the gradio tab +# +# Imports +import re +# +# External Libraries +import gradio as gr +# +###################################################################################################################### +# +# Functions: + +def parse_plantuml_mindmap(plantuml_text: str) -> dict: + """Parse PlantUML mindmap syntax into a nested dictionary structure""" + lines = [line.strip() for line in plantuml_text.split('\n') + if line.strip() and not line.strip().startswith('@')] + + root = None + nodes = [] + stack = [] + + for line in lines: + level_match = re.match(r'^([+\-*]+|\*+)', line) + if not level_match: + continue + level = len(level_match.group(0)) + text = re.sub(r'^([+\-*]+|\*+)\s*', '', line).strip('[]').strip('()') + node = {'text': text, 'children': []} + + while stack and stack[-1][0] >= level: + stack.pop() + + if stack: + stack[-1][1]['children'].append(node) + else: + root = node + + stack.append((level, node)) + + return root + +def create_mindmap_html(plantuml_text: str) -> str: + """Convert PlantUML mindmap to HTML visualization with collapsible nodes using CSS only""" + # Parse the mindmap text into a nested structure + root_node = parse_plantuml_mindmap(plantuml_text) + if not root_node: + return "

<div>No valid mindmap content provided.</div>"
+
+    html = ""
+
+    colors = ['#e6f3ff', '#f0f7ff', '#f5f5f5', '#fff0f0', '#f0fff0']
+
+    def create_node_html(node, level):
+        bg_color = colors[(level - 1) % len(colors)]
+        if node['children']:
+            children_html = ''.join(create_node_html(child, level + 1) for child in node['children'])
+            # Branch node: a collapsible <details> block with its children nested inside
+            return f"""
+            <details open style="margin-left: {level * 20}px; padding: 4px 8px; background-color: {bg_color}; border-radius: 4px;">
+                <summary style="cursor: pointer; font-weight: bold;">{node['text']}</summary>
+                {children_html}
+            </details>
+            """
+        else:
+            # Leaf node: plain block, no toggle
+            return f"""
+            <div style="margin-left: {level * 20}px; padding: 4px 8px; background-color: {bg_color}; border-radius: 4px;">
+                {node['text']}
+            </div>
+ """ + + html += create_node_html(root_node, level=1) + return html + +# Create Gradio interface +def create_mindmap_tab(): + with gr.TabItem("PlantUML Mindmap"): + gr.Markdown("# Collapsible PlantUML Mindmap Visualizer") + gr.Markdown("Convert PlantUML mindmap syntax to a visual mindmap with collapsible nodes.") + plantuml_input = gr.Textbox( + lines=15, + label="Enter PlantUML mindmap", + placeholder="""@startmindmap + * Project Planning + ** Requirements + *** Functional Requirements + **** User Interface + **** Backend Services + *** Technical Requirements + **** Performance + **** Security + ** Timeline + *** Phase 1 + *** Phase 2 + ** Resources + *** Team + *** Budget + @endmindmap""" + ) + submit_btn = gr.Button("Generate Mindmap") + mindmap_output = gr.HTML(label="Mindmap Output") + submit_btn.click( + fn=create_mindmap_html, + inputs=plantuml_input, + outputs=mindmap_output + ) + +# +# End of Mind_Map_tab.py +###################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py b/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py index 25c5ba6ec..04381ab13 100644 --- a/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py @@ -10,7 +10,7 @@ import gradio as gr # # Local Imports -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.PDF.PDF_Ingestion_Lib import extract_metadata_from_pdf, extract_text_and_format_from_pdf, \ process_and_cleanup_pdf @@ -22,87 +22,218 @@ def create_pdf_ingestion_tab(): with gr.TabItem("PDF Ingestion", visible=True): - # TODO - Add functionality to extract metadata from pdf as part of conversion process in marker gr.Markdown("# Ingest PDF Files and Extract Metadata") with gr.Row(): with gr.Column(): - pdf_file_input = gr.File(label="Uploaded PDF File", file_types=[".pdf"], visible=True) - pdf_upload_button = gr.UploadButton("Click to Upload PDF", file_types=[".pdf"]) - pdf_title_input = gr.Textbox(label="Title (Optional)") - pdf_author_input = gr.Textbox(label="Author (Optional)") - pdf_keywords_input = gr.Textbox(label="Keywords (Optional, comma-separated)") - with gr.Row(): - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) - with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) - with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) - with gr.Row(): - system_prompt_input = gr.Textbox(label="System Prompt", - value=""" -You are a bulleted notes specialist. -[INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. 
Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] -**Bulleted Note Creation Guidelines** - -**Headings**: -- Based on referenced topics, not categories like quotes or terms -- Surrounded by **bold** formatting -- Not listed as bullet points -- No space between headings and list items underneath - -**Emphasis**: -- **Important terms** set in bold font -- **Text ending in a colon**: also bolded - -**Review**: -- Ensure adherence to specified format -- Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST]""", - lines=3, - visible=False) - - custom_prompt_checkbox.change( - fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), - inputs=[custom_prompt_checkbox], - outputs=[custom_prompt_input, system_prompt_input] + # Changed to support multiple files + pdf_file_input = gr.File( + label="Uploaded PDF Files", + file_types=[".pdf"], + visible=True, + file_count="multiple" ) - preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), - inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + pdf_upload_button = gr.UploadButton( + "Click to Upload PDFs", + file_types=[".pdf"], + file_count="multiple" ) + # Common metadata for all files + pdf_keywords_input = gr.Textbox(label="Keywords (Optional, comma-separated)") +# with gr.Row(): +# custom_prompt_checkbox = gr.Checkbox( +# label="Use a Custom Prompt", +# value=False, +# visible=True +# ) +# preset_prompt_checkbox = gr.Checkbox( +# label="Use a pre-set Prompt", +# value=False, +# visible=True +# ) +# # Initialize state variables for pagination +# current_page_state = gr.State(value=1) +# total_pages_state = gr.State(value=1) +# with gr.Row(): +# # Add pagination controls +# preset_prompt = gr.Dropdown( +# label="Select Preset Prompt", +# choices=[], +# visible=False +# ) +# prev_page_button = gr.Button("Previous Page", visible=False) +# page_display = gr.Markdown("Page 1 of X", visible=False) +# next_page_button = gr.Button("Next Page", visible=False) +# with gr.Row(): +# custom_prompt_input = gr.Textbox( +# label="Custom Prompt", +# placeholder="Enter custom prompt here", +# lines=3, +# visible=False +# ) +# with gr.Row(): +# system_prompt_input = gr.Textbox( +# label="System Prompt", +# value=""" +# You are a bulleted notes specialist. +# [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] +# **Bulleted Note Creation Guidelines** +# +# **Headings**: +# - Based on referenced topics, not categories like quotes or terms +# - Surrounded by **bold** formatting +# - Not listed as bullet points +# - No space between headings and list items underneath +# +# **Emphasis**: +# - **Important terms** set in bold font +# - **Text ending in a colon**: also bolded +# +# **Review**: +# - Ensure adherence to specified format +# - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST]""", +# lines=3, +# visible=False +# ) +# +# custom_prompt_checkbox.change( +# fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), +# inputs=[custom_prompt_checkbox], +# outputs=[custom_prompt_input, system_prompt_input] +# ) +# +# def on_preset_prompt_checkbox_change(is_checked): +# if is_checked: +# prompts, total_pages, current_page = list_prompts(page=1, per_page=10) +# page_display_text = f"Page {current_page} of {total_pages}" +# return ( +# gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt +# gr.update(visible=True), # prev_page_button +# gr.update(visible=True), # next_page_button +# gr.update(value=page_display_text, visible=True), # page_display +# current_page, # current_page_state +# total_pages # total_pages_state +# ) +# else: +# return ( +# gr.update(visible=False, interactive=False), # preset_prompt +# gr.update(visible=False), # prev_page_button +# gr.update(visible=False), # next_page_button +# gr.update(visible=False), # page_display +# 1, # current_page_state +# 1 # total_pages_state +# ) +# +# preset_prompt_checkbox.change( +# fn=on_preset_prompt_checkbox_change, +# inputs=[preset_prompt_checkbox], +# outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] +# ) +# +# def on_prev_page_click(current_page, total_pages): +# new_page = max(current_page - 1, 1) +# prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) +# page_display_text = f"Page {current_page} of {total_pages}" +# return gr.update(choices=prompts), gr.update(value=page_display_text), current_page +# +# prev_page_button.click( +# fn=on_prev_page_click, +# inputs=[current_page_state, total_pages_state], +# outputs=[preset_prompt, page_display, current_page_state] +# ) +# +# def on_next_page_click(current_page, total_pages): +# new_page = min(current_page + 1, total_pages) +# prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) +# page_display_text = f"Page {current_page} of {total_pages}" +# return gr.update(choices=prompts), gr.update(value=page_display_text), current_page +# +# next_page_button.click( +# fn=on_next_page_click, +# inputs=[current_page_state, total_pages_state], +# outputs=[preset_prompt, page_display, current_page_state] +# ) +# +# def update_prompts(preset_name): +# prompts = update_user_prompt(preset_name) +# return ( +# gr.update(value=prompts["user_prompt"], visible=True), +# gr.update(value=prompts["system_prompt"], visible=True) +# ) +# +# preset_prompt.change( +# update_prompts, +# inputs=preset_prompt, +# outputs=[custom_prompt_input, system_prompt_input] +# ) - def update_prompts(preset_name): - prompts = update_user_prompt(preset_name) - return ( - gr.update(value=prompts["user_prompt"], visible=True), - gr.update(value=prompts["system_prompt"], visible=True) - ) - - preset_prompt.change( - 
update_prompts, - inputs=preset_prompt, - outputs=[custom_prompt_input, system_prompt_input] - ) + pdf_ingest_button = gr.Button("Ingest PDFs") - pdf_ingest_button = gr.Button("Ingest PDF") + # Update the upload button handler for multiple files + pdf_upload_button.upload( + fn=lambda files: files, + inputs=pdf_upload_button, + outputs=pdf_file_input + ) - pdf_upload_button.upload(fn=lambda file: file, inputs=pdf_upload_button, outputs=pdf_file_input) with gr.Column(): - pdf_result_output = gr.Textbox(label="Result") + pdf_result_output = gr.DataFrame( + headers=["Filename", "Status", "Message"], + label="Processing Results" + ) + + # Define a new function to handle multiple PDFs + def process_multiple_pdfs(pdf_files, keywords, custom_prompt_checkbox_value, custom_prompt_text, system_prompt_text): + results = [] + if pdf_files is None: + return [["No files", "Error", "No files uploaded"]] + + for pdf_file in pdf_files: + try: + # Extract metadata from PDF + metadata = extract_metadata_from_pdf(pdf_file.name) + + # Use custom or system prompt if checkbox is checked + if custom_prompt_checkbox_value: + prompt = custom_prompt_text + system_prompt = system_prompt_text + else: + prompt = None + system_prompt = None + + # Process the PDF with prompts + result = process_and_cleanup_pdf( + pdf_file, + metadata.get('title', os.path.splitext(os.path.basename(pdf_file.name))[0]), + metadata.get('author', 'Unknown'), + keywords, + #prompt=prompt, + #system_prompt=system_prompt + ) + + results.append([ + pdf_file.name, + "Success" if "successfully" in result else "Error", + result + ]) + except Exception as e: + results.append([ + pdf_file.name, + "Error", + str(e) + ]) + + return results + # Update the ingest button click handler pdf_ingest_button.click( - fn=process_and_cleanup_pdf, - inputs=[pdf_file_input, pdf_title_input, pdf_author_input, pdf_keywords_input], + fn=process_multiple_pdfs, + inputs=[ + pdf_file_input, + pdf_keywords_input, + #custom_prompt_checkbox, + #custom_prompt_input, + #system_prompt_input + ], outputs=pdf_result_output ) diff --git a/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py b/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py index 7c02f810a..66bdc682f 100644 --- a/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py +++ b/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py @@ -17,7 +17,6 @@ from pypandoc import convert_file # # Import Local libraries -from App_Function_Libraries.Gradio_UI.Import_Functionality import import_data from App_Function_Libraries.Plaintext.Plaintext_Files import import_file_handler from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name # @@ -36,36 +35,58 @@ def create_plain_text_import_tab(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + with gr.TabItem("Import Plain text & .docx Files", visible=True): with gr.Row(): with gr.Column(): - gr.Markdown("# Import Markdown(`.md`)/Text(`.txt`)/rtf & `.docx` Files") - gr.Markdown("Upload a single file or a zip file containing multiple files") - import_file = gr.File(label="Upload file for import", file_types=[".md", ".txt", ".rtf", ".docx", ".zip"]) - title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content (for single files)") - author_input = gr.Textbox(label="Author", placeholder="Enter the author's name (for single files)") - keywords_input = gr.Textbox(label="Keywords", placeholder="Enter keywords, comma-separated") - 
system_prompt_input = gr.Textbox(label="System Prompt (for Summarization)", lines=3, - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] - **Bulleted Note Creation Guidelines** - - **Headings**: - - Based on referenced topics, not categories like quotes or terms - - Surrounded by **bold** formatting - - Not listed as bullet points - - No space between headings and list items underneath + gr.Markdown("# Import `.md`/`.txt`/`.rtf`/`.docx` Files & `.zip` archives containing them.") + gr.Markdown("Upload multiple files or a zip file containing multiple files") - **Emphasis**: - - **Important terms** set in bold font - - **Text ending in a colon**: also bolded + # Updated to support multiple files + import_files = gr.File( + label="Upload files for import", + file_count="multiple", + file_types=[".md", ".txt", ".rtf", ".docx", ".zip"] + ) - **Review**: - - Ensure adherence to specified format - - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST]""", - ) - custom_prompt_input = gr.Textbox(label="Custom User Prompt", placeholder="Enter a custom user prompt for summarization (optional)") + # Optional metadata override fields + author_input = gr.Textbox( + label="Author Override (optional)", + placeholder="Enter author name to apply to all files" + ) + keywords_input = gr.Textbox( + label="Keywords", + placeholder="Enter keywords, comma-separated - will be applied to all files" + ) + system_prompt_input = gr.Textbox( + label="System Prompt (for Summarization)", + lines=3, + value=""" + You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhere to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] + """ + ) + custom_prompt_input = gr.Textbox( + label="Custom User Prompt", + placeholder="Enter a custom user prompt for summarization (optional)" + ) auto_summarize_checkbox = gr.Checkbox(label="Auto-summarize", value=False) - # Refactored API selection dropdown + + # API configuration api_name_input = gr.Dropdown( choices=["None"] + [format_api_name(api) for api in global_api_endpoints], value=default_value, @@ -73,14 +94,27 @@ def create_plain_text_import_tab(): ) api_key_input = gr.Textbox(label="API Key", type="password") import_button = gr.Button("Import File(s)") + with gr.Column(): - import_output = gr.Textbox(label="Import Status") + import_output = gr.Textbox(label="Import Status", lines=10) import_button.click( fn=import_file_handler, - inputs=[import_file, title_input, author_input, keywords_input, system_prompt_input, - custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input], + inputs=[ + import_files, + author_input, + keywords_input, + system_prompt_input, + custom_prompt_input, + auto_summarize_checkbox, + api_name_input, + api_key_input + ], outputs=import_output ) - return import_file, title_input, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output \ No newline at end of file + return import_files, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output + +# +# End of Plain_text_import.py +####################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Podcast_tab.py b/App_Function_Libraries/Gradio_UI/Podcast_tab.py index 2372c1277..84c68be86 100644 --- a/App_Function_Libraries/Gradio_UI/Podcast_tab.py +++ b/App_Function_Libraries/Gradio_UI/Podcast_tab.py @@ -2,25 +2,21 @@ # Description: Gradio UI for ingesting podcasts into the database # # Imports +import logging # # External Imports -import logging - import gradio as gr # # Local Imports from App_Function_Libraries.Audio.Audio_Files import process_podcast -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models, update_user_prompt from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name - - # ######################################################################################################################## # # Functions: - def create_podcast_tab(): try: default_value = None @@ -34,6 +30,10 @@ def create_podcast_tab(): default_value = None with gr.TabItem("Podcast", visible=True): gr.Markdown("# Podcast Transcription and Ingestion", visible=True) + # Initialize state variables for 
pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): podcast_url_input = gr.Textbox(label="Podcast URL", placeholder="Enter the podcast URL here") @@ -50,54 +50,130 @@ def create_podcast_tab(): keep_timestamps_input = gr.Checkbox(label="Keep Timestamps", value=True) with gr.Row(): - podcast_custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + podcast_custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) + + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + with gr.Row(): - podcast_custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + podcast_custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=10, + visible=False + ) with gr.Row(): - system_prompt_input = gr.Textbox(label="System Prompt", - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] -**Bulleted Note Creation Guidelines** - -**Headings**: -- Based on referenced topics, not categories like quotes or terms -- Surrounded by **bold** formatting -- Not listed as bullet points -- No space between headings and list items underneath - -**Emphasis**: -- **Important terms** set in bold font -- **Text ending in a colon**: also bolded - -**Review**: -- Ensure adherence to specified format -- Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] -""", - lines=3, - visible=False) + system_prompt_input = gr.Textbox( + label="System Prompt", + value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. 
Before submitting your response, review the instructions, and make any corrections necessary to adhere to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] + """, + lines=10, + visible=False + ) + + # Handle custom prompt checkbox change podcast_custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[podcast_custom_prompt_checkbox], outputs=[podcast_custom_prompt_input, system_prompt_input] ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) + # Update prompts when a preset is selected def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -106,8 +182,8 @@ def update_prompts(preset_name): ) preset_prompt.change( - update_prompts, - inputs=preset_prompt, + fn=update_prompts, + inputs=[preset_prompt], outputs=[podcast_custom_prompt_input, 
system_prompt_input] ) @@ -166,13 +242,37 @@ def update_prompts(preset_name): podcast_process_button.click( fn=process_podcast, - inputs=[podcast_url_input, podcast_title_input, podcast_author_input, - podcast_keywords_input, podcast_custom_prompt_input, podcast_api_name_input, - podcast_api_key_input, podcast_whisper_model_input, keep_original_input, - enable_diarization_input, use_cookies_input, cookies_input, - chunk_method, max_chunk_size, chunk_overlap, use_adaptive_chunking, - use_multi_level_chunking, chunk_language, keep_timestamps_input], - outputs=[podcast_progress_output, podcast_transcription_output, podcast_summary_output, - podcast_title_input, podcast_author_input, podcast_keywords_input, podcast_error_output, - download_transcription, download_summary] + inputs=[ + podcast_url_input, + podcast_title_input, + podcast_author_input, + podcast_keywords_input, + podcast_custom_prompt_input, + podcast_api_name_input, + podcast_api_key_input, + podcast_whisper_model_input, + keep_original_input, + enable_diarization_input, + use_cookies_input, + cookies_input, + chunk_method, + max_chunk_size, + chunk_overlap, + use_adaptive_chunking, + use_multi_level_chunking, + chunk_language, + keep_timestamps_input, + system_prompt_input # Include system prompt input + ], + outputs=[ + podcast_progress_output, + podcast_transcription_output, + podcast_summary_output, + podcast_title_input, + podcast_author_input, + podcast_keywords_input, + podcast_error_output, + download_transcription, + download_summary + ] ) \ No newline at end of file diff --git a/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py b/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py index 418b11f32..bcbd9ca2e 100644 --- a/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py @@ -5,8 +5,8 @@ import gradio as gr -from App_Function_Libraries.Chat import chat -from App_Function_Libraries.DB.SQLite_DB import add_or_update_prompt +from App_Function_Libraries.Chat.Chat_Functions import chat +from App_Function_Libraries.DB.DB_Manager import add_or_update_prompt from App_Function_Libraries.Prompt_Engineering.Prompt_Engineering import generate_prompt, test_generated_prompt from App_Function_Libraries.Utils.Utils import format_api_name, global_api_endpoints, default_api_endpoint diff --git a/App_Function_Libraries/Gradio_UI/Prompts_tab.py b/App_Function_Libraries/Gradio_UI/Prompts_tab.py new file mode 100644 index 000000000..dedfcf9d9 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Prompts_tab.py @@ -0,0 +1,297 @@ +# Prompts_tab.py +# Description: This file contains the code for the prompts tab in the Gradio UI +# +# Imports +import html +import logging + +# +# External Imports +import gradio as gr +# +# Local Imports +from App_Function_Libraries.DB.DB_Manager import fetch_prompt_details, list_prompts +# +#################################################################################################### +# +# Functions: + +def create_prompt_view_tab(): + with gr.TabItem("View Prompt Database", visible=True): + gr.Markdown("# View Prompt Database Entries") + with gr.Row(): + with gr.Column(): + entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) + page_number = gr.Number(value=1, label="Page Number", precision=0) + view_button = gr.Button("View Page") + previous_page_button = gr.Button("Previous Page", visible=True) + next_page_button = gr.Button("Next Page", visible=True) + pagination_info = 
gr.Textbox(label="Pagination Info", interactive=False) + prompt_selector = gr.Dropdown(label="Select Prompt to View", choices=[]) + with gr.Column(): + results_table = gr.HTML() + selected_prompt_display = gr.HTML() + + # Function to view database entries + def view_database(page, entries_per_page): + try: + # Use list_prompts to get prompts and total pages + prompts, total_pages, current_page = list_prompts(page=int(page), per_page=int(entries_per_page)) + + table_html = "<table>" + table_html += "<tr><th>Title</th><th>Author</th></tr>" + prompt_choices = [] + for prompt_name in prompts: + details = fetch_prompt_details(prompt_name) + if details: + title, author, _, _, _, _ = details + author = author or "Unknown" # Handle None author + table_html += f"<tr><td>{html.escape(title)}</td><td>{html.escape(author)}</td></tr>" + prompt_choices.append(prompt_name) # Using prompt_name as value + table_html += "</table>
" + + # Get total prompts if possible + total_prompts = total_pages * int(entries_per_page) # This might overestimate if the last page is not full + + pagination = f"Page {current_page} of {total_pages} (Total prompts: {total_prompts})" + + return table_html, pagination, total_pages, prompt_choices + except Exception as e: + return f"
<p>Error fetching prompts: {e}</p>
", "Error", 0, [] + + # Function to update page content + def update_page(page, entries_per_page): + results, pagination, total_pages, prompt_choices = view_database(page, entries_per_page) + page = int(page) + next_disabled = page >= total_pages + prev_disabled = page <= 1 + return ( + results, + pagination, + page, + gr.update(visible=True, interactive=not prev_disabled), # previous_page_button + gr.update(visible=True, interactive=not next_disabled), # next_page_button + gr.update(choices=prompt_choices) + ) + + # Function to go to the next page + def go_to_next_page(current_page, entries_per_page): + next_page = int(current_page) + 1 + return update_page(next_page, entries_per_page) + + # Function to go to the previous page + def go_to_previous_page(current_page, entries_per_page): + previous_page = max(1, int(current_page) - 1) + return update_page(previous_page, entries_per_page) + + # Function to display selected prompt details + def display_selected_prompt(prompt_name): + details = fetch_prompt_details(prompt_name) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + # Handle None values by converting them to empty strings + description = description or "" + system_prompt = system_prompt or "" + user_prompt = user_prompt or "" + author = author or "Unknown" + keywords = keywords or "" + + html_content = f""" +
<div style="border: 1px solid #ddd; padding: 10px;">
+                <h3>{html.escape(title)}</h3>
+                <p>by {html.escape(author)}</p>
+                <p><strong>Description:</strong> {html.escape(description)}</p>
+                <div>
+                    <strong>System Prompt:</strong>
+                    <pre>{html.escape(system_prompt)}</pre>
+                </div>
+                <div>
+                    <strong>User Prompt:</strong>
+                    <pre>{html.escape(user_prompt)}</pre>
+                </div>
+                <p><strong>Keywords:</strong> {html.escape(keywords)}</p>
+            </div>
+            """
+            return html_content
+        else:
+            return "<p>Prompt not found.</p>
" + + # Event handlers + view_button.click( + fn=update_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + next_page_button.click( + fn=go_to_next_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + previous_page_button.click( + fn=go_to_previous_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + prompt_selector.change( + fn=display_selected_prompt, + inputs=[prompt_selector], + outputs=[selected_prompt_display] + ) + + + +def create_prompts_export_tab(): + """Creates a tab for exporting prompts database content with multiple format options""" + with gr.TabItem("Export Prompts", visible=True): + gr.Markdown("# Export Prompts Database Content") + + with gr.Row(): + with gr.Column(): + export_type = gr.Radio( + choices=["All Prompts", "Prompts by Keyword"], + label="Export Type", + value="All Prompts" + ) + + # Keyword selection for filtered export + with gr.Column(visible=False) as keyword_col: + keyword_input = gr.Textbox( + label="Enter Keywords (comma-separated)", + placeholder="Enter keywords to filter prompts..." + ) + + # Export format selection + export_format = gr.Radio( + choices=["CSV", "Markdown (ZIP)"], + label="Export Format", + value="CSV" + ) + + # Export options + include_options = gr.CheckboxGroup( + choices=[ + "Include System Prompts", + "Include User Prompts", + "Include Details", + "Include Author", + "Include Keywords" + ], + label="Export Options", + value=["Include Keywords", "Include Author"] + ) + + # Markdown-specific options (only visible when Markdown is selected) + with gr.Column(visible=False) as markdown_options_col: + markdown_template = gr.Radio( + choices=[ + "Basic Template", + "Detailed Template", + "Custom Template" + ], + label="Markdown Template", + value="Basic Template" + ) + custom_template = gr.Textbox( + label="Custom Template", + placeholder="Use {title}, {author}, {details}, {system}, {user}, {keywords} as placeholders", + visible=False + ) + + export_button = gr.Button("Export Prompts") + + with gr.Column(): + export_status = gr.Textbox(label="Export Status", interactive=False) + export_file = gr.File(label="Download Export") + + def update_ui_visibility(export_type, format_choice, template_choice): + """Update UI elements visibility based on selections""" + show_keywords = export_type == "Prompts by Keyword" + show_markdown_options = format_choice == "Markdown (ZIP)" + show_custom_template = template_choice == "Custom Template" and show_markdown_options + + return [ + gr.update(visible=show_keywords), # keyword_col + gr.update(visible=show_markdown_options), # markdown_options_col + gr.update(visible=show_custom_template) # custom_template + ] + + def handle_export(export_type, keywords, export_format, options, markdown_template, custom_template): + """Handle the export process based on selected options""" + try: + # Parse options + include_system = "Include System Prompts" in options + include_user = "Include User Prompts" in options + include_details = "Include Details" in options + include_author = "Include Author" in options + include_keywords = "Include Keywords" in options + + # Handle keyword filtering + keyword_list = None + if export_type == "Prompts by Keyword" and keywords: + 
keyword_list = [k.strip() for k in keywords.split(",") if k.strip()] + + # Get the appropriate template + template = None + if export_format == "Markdown (ZIP)": + if markdown_template == "Custom Template": + template = custom_template + else: + template = markdown_template + + # Perform export + from App_Function_Libraries.DB.Prompts_DB import export_prompts + status, file_path = export_prompts( + export_format=export_format.split()[0].lower(), # 'csv' or 'markdown' + filter_keywords=keyword_list, + include_system=include_system, + include_user=include_user, + include_details=include_details, + include_author=include_author, + include_keywords=include_keywords, + markdown_template=template + ) + + return status, file_path + + except Exception as e: + error_msg = f"Export failed: {str(e)}" + logging.error(error_msg) + return error_msg, None + + # Event handlers + export_type.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_format.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + markdown_template.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_button.click( + fn=handle_export, + inputs=[ + export_type, + keyword_input, + export_format, + include_options, + markdown_template, + custom_template + ], + outputs=[export_status, export_file] + ) + +# +# End of Prompts_tab.py +#################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py index 76d316e0e..0c88aaeb3 100644 --- a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py @@ -6,6 +6,7 @@ import logging import json import os +import re from datetime import datetime # # External Imports @@ -14,27 +15,21 @@ # # Local Imports from App_Function_Libraries.Books.Book_Ingestion_Lib import read_epub -from App_Function_Libraries.DB.DB_Manager import DatabaseError, get_paginated_files, add_media_with_keywords -from App_Function_Libraries.DB.RAG_QA_Chat_DB import ( - save_notes, - add_keywords_to_note, - start_new_conversation, - save_message, - search_conversations_by_keywords, - load_chat_history, - get_all_conversations, - get_note_by_id, - get_notes_by_keywords, - get_notes_by_keyword_collection, - update_note, - clear_keywords_from_note, get_notes, get_keywords_for_note, delete_conversation, delete_note, execute_query, - add_keywords_to_conversation, fetch_all_notes, fetch_all_conversations, fetch_conversations_by_ids, - fetch_notes_by_ids, -) +from App_Function_Libraries.DB.Character_Chat_DB import search_character_chat, search_character_cards +from App_Function_Libraries.DB.DB_Manager import DatabaseError, get_paginated_files, add_media_with_keywords, \ + get_all_conversations, get_note_by_id, get_notes_by_keywords, start_new_conversation, update_note, save_notes, \ + clear_keywords_from_note, add_keywords_to_note, load_chat_history, save_message, add_keywords_to_conversation, \ + get_keywords_for_note, delete_note, search_conversations_by_keywords, get_conversation_title, delete_conversation, \ + 
update_conversation_title, fetch_all_conversations, fetch_all_notes, fetch_conversations_by_ids, fetch_notes_by_ids, \ + search_media_db, search_notes_titles, list_prompts +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_notes, delete_messages_in_conversation, search_rag_notes, \ + search_rag_chat, get_conversation_rating, set_conversation_rating +from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_user_prompt from App_Function_Libraries.PDF.PDF_Ingestion_Lib import extract_text_and_format_from_pdf from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer, enhanced_rag_pipeline from App_Function_Libraries.RAG.RAG_QA_Chat import search_database, rag_qa_chat -from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name, \ + load_comprehensive_config + @@ -60,18 +55,53 @@ def create_rag_qa_chat_tab(): "page": 1, "context_source": "Entire Media Database", "conversation_messages": [], + "conversation_id": None }) note_state = gr.State({"note_id": None}) + def auto_save_conversation(message, response, state_value, auto_save_enabled): + """Automatically save the conversation if auto-save is enabled""" + try: + if not auto_save_enabled: + return state_value + + conversation_id = state_value.get("conversation_id") + if not conversation_id: + # Create new conversation with default title + title = "Auto-saved Conversation " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_id = start_new_conversation(title=title) + state_value = state_value.copy() + state_value["conversation_id"] = conversation_id + + # Save the messages + save_message(conversation_id, "user", message) + save_message(conversation_id, "assistant", response) + + return state_value + except Exception as e: + logging.error(f"Error in auto-save: {str(e)}") + return state_value + # Update the conversation list function def update_conversation_list(): conversations, total_pages, total_count = get_all_conversations() - choices = [f"{title} (ID: {conversation_id})" for conversation_id, title in conversations] + choices = [ + f"{conversation['title']} (ID: {conversation['conversation_id']}) - Rating: {conversation['rating'] or 'Not Rated'}" + for conversation in conversations + ] return choices with gr.Row(): with gr.Column(scale=1): + # FIXME - Offer the user to search 2+ databases at once + database_types = ["Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards"] + db_choice = gr.CheckboxGroup( + label="Select Database(s)", + choices=database_types, + value=["Media DB"], + interactive=True + ) context_source = gr.Radio( ["All Files in the Database", "Search Database", "Upload File"], label="Context Source", @@ -84,19 +114,52 @@ def update_conversation_list(): next_page_btn = gr.Button("Next Page") page_info = gr.HTML("Page 1") top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True) - keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True) + keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by", value="rag_qa_default_keyword", visible=True) use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True) - # with gr.Row(): - # page_number = gr.Number(value=1, label="Page", precision=0) - 
# page_size = gr.Number(value=20, label="Items per page", precision=0) - # total_pages = gr.Number(label="Total Pages", interactive=False) + config = load_comprehensive_config() + auto_save_value = config.getboolean('auto-save', 'save_character_chats', fallback=False) + auto_save_checkbox = gr.Checkbox( + label="Save chats automatically", + value=auto_save_value, + info="When enabled, conversations will be saved automatically after each message" + ) + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + + preset_prompt_checkbox = gr.Checkbox( + label="View Custom Prompts(have to copy/paste them)", + value=False, + visible=True + ) + + with gr.Row(visible=False) as preset_prompt_controls: + prev_prompt_page = gr.Button("Previous") + current_prompt_page_text = gr.Text(f"Page {current_page} of {total_pages}") + next_prompt_page = gr.Button("Next") + current_prompt_page_state = gr.State(value=1) + + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=initial_prompts, + visible=False + ) + user_prompt = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False + ) + + system_prompt_input = gr.Textbox( + label="System Prompt", + lines=3, + visible=False + ) search_query = gr.Textbox(label="Search Query", visible=False) search_button = gr.Button("Search", visible=False) search_results = gr.Dropdown(label="Search Results", choices=[], visible=False) - # FIXME - Add pages for search results handling file_upload = gr.File( label="Upload File", visible=False, @@ -108,14 +171,23 @@ def update_conversation_list(): load_conversation = gr.Dropdown( label="Load Conversation", choices=update_conversation_list() - ) + ) new_conversation = gr.Button("New Conversation") save_conversation_button = gr.Button("Save Conversation") conversation_title = gr.Textbox( - label="Conversation Title", placeholder="Enter a title for the new conversation" + label="Conversation Title", + placeholder="Enter a title for the new conversation" ) keywords = gr.Textbox(label="Keywords (comma-separated)", visible=True) + # Add the rating display and input + rating_display = gr.Markdown(value="", visible=False) + rating_input = gr.Radio( + choices=["1", "2", "3"], + label="Rate this Conversation (1-3 stars)", + visible=False + ) + # Refactored API selection dropdown api_choice = gr.Dropdown( choices=["None"] + [format_api_name(api) for api in global_api_endpoints], @@ -143,6 +215,8 @@ def update_conversation_list(): clear_notes_btn = gr.Button("Clear Current Note text") new_note_btn = gr.Button("New Note") + # FIXME - Change from only keywords to generalized search + search_notes_title = gr.Textbox(label="Search Notes by Title") search_notes_by_keyword = gr.Textbox(label="Search Notes by Keyword") search_notes_button = gr.Button("Search Notes") note_results = gr.Dropdown(label="Notes", choices=[]) @@ -150,8 +224,58 @@ def update_conversation_list(): loading_indicator = gr.HTML("Loading...", visible=False) status_message = gr.HTML() + auto_save_status = gr.HTML() + + # Function Definitions + def update_prompt_page(direction, current_page_val): + new_page = max(1, min(total_pages, current_page_val + direction)) + prompts, _, _ = list_prompts(page=new_page, per_page=10) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return ( + gr.update(value=prompts["user_prompt"], visible=True), + 
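+                # first element fills the custom user prompt box, the next fills the system prompt box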
gr.update(value=prompts["system_prompt"], visible=True) + ) + + def toggle_preset_prompt(checkbox_value): + return ( + gr.update(visible=checkbox_value), + gr.update(visible=checkbox_value), + gr.update(visible=False), + gr.update(visible=False) + ) + + prev_prompt_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + next_prompt_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + preset_prompt.change( + update_prompts, + inputs=preset_prompt, + outputs=[user_prompt, system_prompt_input] + ) + + preset_prompt_checkbox.change( + toggle_preset_prompt, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, preset_prompt_controls, user_prompt, system_prompt_input] + ) def update_state(state, **kwargs): new_state = state.copy() @@ -166,18 +290,28 @@ def create_new_note(): outputs=[note_title, notes, note_state] ) - def search_notes(keywords): + def search_notes(search_notes_title, keywords): if keywords: keywords_list = [kw.strip() for kw in keywords.split(',')] notes_data, total_pages, total_count = get_notes_by_keywords(keywords_list) - choices = [f"Note {note_id} ({timestamp})" for note_id, title, content, timestamp in notes_data] - return gr.update(choices=choices) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") + elif search_notes_title: + notes_data, total_pages, total_count = search_notes_titles(search_notes_title) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") else: - return gr.update(choices=[]) + # This will now return all notes, ordered by timestamp + notes_data, total_pages, total_count = search_notes_titles("") + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"All notes ({total_count} total)") search_notes_button.click( search_notes, - inputs=[search_notes_by_keyword], + inputs=[search_notes_title, search_notes_by_keyword], outputs=[note_results] ) @@ -273,83 +407,112 @@ def clear_notes_function(): outputs=[notes, note_state] ) - def update_conversation_list(): - conversations, total_pages, total_count = get_all_conversations() - choices = [f"{title} (ID: {conversation_id})" for conversation_id, title in conversations] - return choices - # Initialize the conversation list load_conversation.choices = update_conversation_list() def load_conversation_history(selected_conversation, state_value): - if selected_conversation: - conversation_id = selected_conversation.split('(ID: ')[1][:-1] + try: + if not selected_conversation: + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) + # Extract conversation ID + match = re.search(r'\(ID: ([0-9a-fA-F\-]+)\)', selected_conversation) + if not match: + logging.error(f"Invalid conversation format: {selected_conversation}") + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) + conversation_id = match.group(1) chat_data, total_pages_val, _ = load_chat_history(conversation_id, 1, 50) - # Convert chat data to 
list of tuples (user_message, assistant_response) + # Update state with valid conversation id + updated_state = state_value.copy() + updated_state["conversation_id"] = conversation_id + updated_state["conversation_messages"] = chat_data + # Format chat history history = [] for role, content in chat_data: if role == 'user': history.append((content, '')) - else: - if history: - history[-1] = (history[-1][0], content) - else: - history.append(('', content)) - # Retrieve notes + elif history: + history[-1] = (history[-1][0], content) + # Fetch and display the conversation rating + rating = get_conversation_rating(conversation_id) + if rating is not None: + rating_text = f"**Current Rating:** {rating} star(s)" + rating_display_update = gr.update(value=rating_text, visible=True) + rating_input_update = gr.update(value=str(rating), visible=True) + else: + rating_display_update = gr.update(value="**Current Rating:** Not Rated", visible=True) + rating_input_update = gr.update(value=None, visible=True) notes_content = get_notes(conversation_id) - updated_state = update_state(state_value, conversation_id=conversation_id, page=1, - conversation_messages=[]) - return history, updated_state, "\n".join(notes_content) - return [], state_value, "" + return history, updated_state, "\n".join( + notes_content) if notes_content else "", rating_display_update, rating_input_update + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) load_conversation.change( load_conversation_history, inputs=[load_conversation, state], - outputs=[chatbot, state, notes] + outputs=[chatbot, state, notes, rating_display, rating_input] ) # Modify save_conversation_function to use gr.update() - def save_conversation_function(conversation_title_text, keywords_text, state_value): + def save_conversation_function(conversation_title_text, keywords_text, rating_value, state_value): conversation_messages = state_value.get("conversation_messages", []) + conversation_id = state_value.get("conversation_id") if not conversation_messages: return gr.update( value="
<p>No conversation to save.</p>
" - ), state_value, gr.update() - # Start a new conversation in the database - new_conversation_id = start_new_conversation( - conversation_title_text if conversation_title_text else "Untitled Conversation" - ) + ), state_value, gr.update(), gr.update(value="", visible=False), gr.update(visible=False) + # Start a new conversation in the database if not existing + if not conversation_id: + conversation_id = start_new_conversation( + conversation_title_text if conversation_title_text else "Untitled Conversation" + ) + else: + # Update the conversation title if it has changed + update_conversation_title(conversation_id, conversation_title_text) # Save the messages for role, content in conversation_messages: - save_message(new_conversation_id, role, content) + save_message(conversation_id, role, content) # Save keywords if provided if keywords_text: - add_keywords_to_conversation(new_conversation_id, [kw.strip() for kw in keywords_text.split(',')]) + add_keywords_to_conversation(conversation_id, [kw.strip() for kw in keywords_text.split(',')]) + # Save the rating if provided + try: + if rating_value: + set_conversation_rating(conversation_id, int(rating_value)) + except ValueError as ve: + logging.error(f"Invalid rating value: {ve}") + return gr.update( + value=f"
<p>Invalid rating: {ve}</p>
" + ), state_value, gr.update(), gr.update(value="", visible=False), gr.update(visible=False) + # Update state - updated_state = update_state(state_value, conversation_id=new_conversation_id) + updated_state = update_state(state_value, conversation_id=conversation_id) # Update the conversation list conversation_choices = update_conversation_list() + # Reset rating display and input + rating_display_update = gr.update(value=f"**Current Rating:** {rating_value} star(s)", visible=True) + rating_input_update = gr.update(value=rating_value, visible=True) return gr.update( value="
<p>Conversation saved successfully.</p>
" - ), updated_state, gr.update(choices=conversation_choices) + ), updated_state, gr.update(choices=conversation_choices), rating_display_update, rating_input_update save_conversation_button.click( save_conversation_function, - inputs=[conversation_title, keywords, state], - outputs=[status_message, state, load_conversation] + inputs=[conversation_title, keywords, rating_input, state], + outputs=[status_message, state, load_conversation, rating_display, rating_input] ) def start_new_conversation_wrapper(title, state_value): - # Reset the state with no conversation_id - updated_state = update_state(state_value, conversation_id=None, page=1, - conversation_messages=[]) - # Clear the chat history - return [], updated_state + # Reset the state with no conversation_id and empty conversation messages + updated_state = update_state(state_value, conversation_id=None, page=1, conversation_messages=[]) + # Clear the chat history and reset rating components + return [], updated_state, gr.update(value="", visible=False), gr.update(value=None, visible=False) new_conversation.click( start_new_conversation_wrapper, inputs=[conversation_title, state], - outputs=[chatbot, state] + outputs=[chatbot, state, rating_display, rating_input] ) def update_file_list(page): @@ -364,11 +527,12 @@ def prev_page_fn(current_page): return update_file_list(max(1, current_page - 1)) def update_context_source(choice): + # Update visibility based on context source choice return { existing_file: gr.update(visible=choice == "Existing File"), - prev_page_btn: gr.update(visible=choice == "Existing File"), - next_page_btn: gr.update(visible=choice == "Existing File"), - page_info: gr.update(visible=choice == "Existing File"), + prev_page_btn: gr.update(visible=choice == "Search Database"), + next_page_btn: gr.update(visible=choice == "Search Database"), + page_info: gr.update(visible=choice == "Search Database"), search_query: gr.update(visible=choice == "Search Database"), search_button: gr.update(visible=choice == "Search Database"), search_results: gr.update(visible=choice == "Search Database"), @@ -388,17 +552,36 @@ def update_context_source(choice): context_source.change(lambda choice: update_file_list(1) if choice == "Existing File" else (gr.update(), gr.update(), 1), inputs=[context_source], outputs=[existing_file, page_info, file_page]) - def perform_search(query): + def perform_search(query, selected_databases, keywords): try: - results = search_database(query) + results = [] + + # Iterate over selected database types and perform searches accordingly + for database_type in selected_databases: + if database_type == "Media DB": + # FIXME - check for existence of keywords before setting as search field + search_fields = ["title", "content", "keywords"] + results += search_media_db(query, search_fields, keywords, page=1, results_per_page=25) + elif database_type == "RAG Chat": + results += search_rag_chat(query) + elif database_type == "RAG Notes": + results += search_rag_notes(query) + elif database_type == "Character Chat": + results += search_character_chat(query) + elif database_type == "Character Cards": + results += search_character_cards(query) + + # Remove duplicate results if necessary + results = list(set(results)) return gr.update(choices=results) except Exception as e: gr.Error(f"Error performing search: {str(e)}") return gr.update(choices=[]) + # Click Event for the DB Search Button search_button.click( perform_search, - inputs=[search_query], + inputs=[search_query, db_choice, keywords_input], 
outputs=[search_results] ) @@ -420,17 +603,22 @@ def rephrase_question(history, latest_question, api_choice): logging.info(f"Rephrased question: {rephrased_question}") return rephrased_question.strip() - def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload, - convert_to_text, keywords, api_choice, use_query_rewriting, state_value, - keywords_input, top_k_input, use_re_ranking): + # FIXME - RAG DB selection + def rag_qa_chat_wrapper( + message, history, context_source, existing_file, search_results, file_upload, + convert_to_text, keywords, api_choice, use_query_rewriting, state_value, + keywords_input, top_k_input, use_re_ranking, db_choices, auto_save_enabled + ): try: logging.info(f"Starting rag_qa_chat_wrapper with message: {message}") logging.info(f"Context source: {context_source}") logging.info(f"API choice: {api_choice}") logging.info(f"Query rewriting: {'enabled' if use_query_rewriting else 'disabled'}") + logging.info(f"Selected DB Choices: {db_choices}") # Show loading indicator - yield history, "", gr.update(visible=True), state_value + yield history, "", gr.update(visible=True), state_value, gr.update(visible=False), gr.update( + visible=False) conversation_id = state_value.get("conversation_id") conversation_messages = state_value.get("conversation_messages", []) @@ -444,12 +632,12 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ state_value["conversation_messages"] = conversation_messages # Ensure api_choice is a string - api_choice = api_choice.value if isinstance(api_choice, gr.components.Dropdown) else api_choice - logging.info(f"Resolved API choice: {api_choice}") + api_choice_str = api_choice.value if isinstance(api_choice, gr.components.Dropdown) else api_choice + logging.info(f"Resolved API choice: {api_choice_str}") # Only rephrase the question if it's not the first query and query rewriting is enabled if len(history) > 0 and use_query_rewriting: - rephrased_question = rephrase_question(history, message, api_choice) + rephrased_question = rephrase_question(history, message, api_choice_str) logging.info(f"Original question: {message}") logging.info(f"Rephrased question: {rephrased_question}") else: @@ -457,18 +645,20 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ logging.info(f"Using original question: {message}") if context_source == "All Files in the Database": - # Use the enhanced_rag_pipeline to search the entire database - context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, top_k_input, - use_re_ranking) + # Use the enhanced_rag_pipeline to search the selected databases + context = enhanced_rag_pipeline( + rephrased_question, api_choice_str, keywords_input, top_k_input, use_re_ranking, + database_types=db_choices # Pass the list of selected databases + ) logging.info(f"Using enhanced_rag_pipeline for database search") elif context_source == "Search Database": context = f"media_id:{search_results.split('(ID: ')[1][:-1]}" logging.info(f"Using search result with context: {context}") - else: # Upload File + else: + # Upload File logging.info("Processing uploaded file") if file_upload is None: raise ValueError("No file uploaded") - # Process the uploaded file file_path = file_upload.name file_name = os.path.basename(file_path) @@ -481,7 +671,6 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ logging.info("Reading file content") with open(file_path, 'r', encoding='utf-8') as f: content = 
f.read() - logging.info(f"File content length: {len(content)} characters") # Process keywords @@ -503,18 +692,17 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ author='Unknown', ingestion_date=datetime.now().strftime('%Y-%m-%d') ) - logging.info(f"Result from add_media_with_keywords: {result}") if isinstance(result, tuple): media_id, _ = result else: media_id = result - context = f"media_id:{media_id}" logging.info(f"Context for uploaded file: {context}") logging.info("Calling rag_qa_chat function") - new_history, response = rag_qa_chat(rephrased_question, history, context, api_choice) + new_history, response = rag_qa_chat(rephrased_question, history, context, api_choice_str) + # Log first 100 chars of response logging.info(f"Response received from rag_qa_chat: {response[:100]}...") @@ -526,7 +714,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ state_value["conversation_messages"] = conversation_messages # Update the state - state_value["conversation_messages"] = conversation_messages + updated_state = auto_save_conversation(message, response, state_value, auto_save_enabled) + updated_state["conversation_messages"] = conversation_messages # Safely update history if new_history: @@ -534,24 +723,43 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ else: new_history = [(message, response)] + # Get the current rating and update display + conversation_id = updated_state.get("conversation_id") + if conversation_id: + rating = get_conversation_rating(conversation_id) + if rating is not None: + rating_display_update = gr.update(value=f"**Current Rating:** {rating} star(s)", visible=True) + rating_input_update = gr.update(value=str(rating), visible=True) + else: + rating_display_update = gr.update(value="**Current Rating:** Not Rated", visible=True) + rating_input_update = gr.update(value=None, visible=True) + else: + rating_display_update = gr.update(value="", visible=False) + rating_input_update = gr.update(value=None, visible=False) + gr.Info("Response generated successfully") logging.info("rag_qa_chat_wrapper completed successfully") - yield new_history, "", gr.update(visible=False), state_value # Include state_value in outputs + yield new_history, "", gr.update( + visible=False), updated_state, rating_display_update, rating_input_update + except ValueError as e: logging.error(f"Input error in rag_qa_chat_wrapper: {str(e)}") gr.Error(f"Input error: {str(e)}") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) except DatabaseError as e: logging.error(f"Database error in rag_qa_chat_wrapper: {str(e)}") gr.Error(f"Database error: {str(e)}") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) except Exception as e: logging.error(f"Unexpected error in rag_qa_chat_wrapper: {e}", exc_info=True) gr.Error("An unexpected error occurred. 
Please try again later.") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) def clear_chat_history(): - return [], "" + return [], "", gr.update(value="", visible=False), gr.update(value=None, visible=False) submit.click( rag_qa_chat_wrapper, @@ -568,14 +776,17 @@ def clear_chat_history(): use_query_rewriting, state, keywords_input, - top_k_input + top_k_input, + use_re_ranking, + db_choice, + auto_save_checkbox ], - outputs=[chatbot, msg, loading_indicator, state], + outputs=[chatbot, msg, loading_indicator, state, rating_display, rating_input], ) clear_chat.click( clear_chat_history, - outputs=[chatbot, msg] + outputs=[chatbot, msg, rating_display, rating_input] ) return ( @@ -608,7 +819,8 @@ def create_rag_qa_notes_management_tab(): with gr.Row(): with gr.Column(scale=1): # Search Notes - search_notes_input = gr.Textbox(label="Search Notes by Keywords") + search_notes_title = gr.Textbox(label="Search Notes by Title") + search_notes_by_keyword = gr.Textbox(label="Search Notes by Keywords") search_notes_button = gr.Button("Search Notes") notes_list = gr.Dropdown(label="Notes", choices=[]) @@ -617,24 +829,34 @@ def create_rag_qa_notes_management_tab(): delete_note_button = gr.Button("Delete Note") note_title_input = gr.Textbox(label="Note Title") note_content_input = gr.TextArea(label="Note Content", lines=20) - note_keywords_input = gr.Textbox(label="Note Keywords (comma-separated)") + note_keywords_input = gr.Textbox(label="Note Keywords (comma-separated)", value="default_note_keyword") save_note_button = gr.Button("Save Note") create_new_note_button = gr.Button("Create New Note") status_message = gr.HTML() # Function Definitions - def search_notes(keywords): + def search_notes(search_notes_title, keywords): if keywords: keywords_list = [kw.strip() for kw in keywords.split(',')] notes_data, total_pages, total_count = get_notes_by_keywords(keywords_list) - choices = [f"{title} (ID: {note_id})" for note_id, title, content, timestamp in notes_data] - return gr.update(choices=choices) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") + elif search_notes_title: + notes_data, total_pages, total_count = search_notes_titles(search_notes_title) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") else: - return gr.update(choices=[]) + # This will now return all notes, ordered by timestamp + notes_data, total_pages, total_count = search_notes_titles("") + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"All notes ({total_count} total)") search_notes_button.click( search_notes, - inputs=[search_notes_input], + inputs=[search_notes_title, search_notes_by_keyword], outputs=[notes_list] ) @@ -698,7 +920,7 @@ def delete_selected_note(state_value): # Reset state state_value["selected_note_id"] = None # Update notes list - updated_notes = search_notes("") + updated_notes = search_notes("", "") return updated_notes, gr.update(value="Note deleted successfully."), state_value else: return gr.update(), gr.update(value="No note selected."), state_value @@ -736,7 +958,20 @@ def 
create_rag_qa_chat_management_tab(): with gr.Row(): with gr.Column(scale=1): # Search Conversations - search_conversations_input = gr.Textbox(label="Search Conversations by Keywords") + with gr.Group(): + gr.Markdown("## Search Conversations") + title_search = gr.Textbox( + label="Search by Title", + placeholder="Enter title to search..." + ) + content_search = gr.Textbox( + label="Search in Chat Content", + placeholder="Enter text to search in messages..." + ) + keyword_search = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="keyword1, keyword2, ..." + ) search_conversations_button = gr.Button("Search Conversations") conversations_list = gr.Dropdown(label="Conversations", choices=[]) new_conversation_button = gr.Button("New Conversation") @@ -750,26 +985,40 @@ def create_rag_qa_chat_management_tab(): status_message = gr.HTML() # Function Definitions - def search_conversations(keywords): - if keywords: - keywords_list = [kw.strip() for kw in keywords.split(',')] - conversations, total_pages, total_count = search_conversations_by_keywords(keywords_list) - else: - conversations, total_pages, total_count = get_all_conversations() + def search_conversations(title_query, content_query, keywords): + try: + # Parse keywords if provided + keywords_list = None + if keywords and keywords.strip(): + keywords_list = [kw.strip() for kw in keywords.split(',')] + + # Search using existing search_conversations_by_keywords function with all criteria + results, total_pages, total_count = search_conversations_by_keywords( + keywords=keywords_list, + title_query=title_query if title_query.strip() else None, + content_query=content_query if content_query.strip() else None + ) - # Build choices as list of titles (ensure uniqueness) - choices = [] - mapping = {} - for conversation_id, title in conversations: - display_title = f"{title} (ID: {conversation_id[:8]})" - choices.append(display_title) - mapping[display_title] = conversation_id + # Build choices as list of titles (ensure uniqueness) + choices = [] + mapping = {} + for conv in results: + conversation_id = conv['conversation_id'] + title = conv['title'] + display_title = f"{title} (ID: {conversation_id[:8]})" + choices.append(display_title) + mapping[display_title] = conversation_id + + return gr.update(choices=choices), mapping - return gr.update(choices=choices), mapping + except Exception as e: + logging.error(f"Error in search_conversations: {str(e)}") + return gr.update(choices=[]), {} + # Update the search button click event search_conversations_button.click( search_conversations, - inputs=[search_conversations_input], + inputs=[title_search, content_search, keyword_search], outputs=[conversations_list, conversation_mapping] ) @@ -926,19 +1175,18 @@ def create_new_conversation(state_value, mapping): ] ) - def delete_messages_in_conversation(conversation_id): - """Helper function to delete all messages in a conversation.""" + def delete_messages_in_conversation_wrapper(conversation_id): + """Wrapper function to delete all messages in a conversation.""" try: - execute_query("DELETE FROM rag_qa_chats WHERE conversation_id = ?", (conversation_id,)) + delete_messages_in_conversation(conversation_id) logging.info(f"Messages in conversation '{conversation_id}' deleted successfully.") except Exception as e: logging.error(f"Error deleting messages in conversation '{conversation_id}': {e}") raise - def get_conversation_title(conversation_id): + def get_conversation_title_wrapper(conversation_id): """Helper function to get the 
conversation title.""" - query = "SELECT title FROM conversation_metadata WHERE conversation_id = ?" - result = execute_query(query, (conversation_id,)) + result = get_conversation_title(conversation_id) if result: return result[0][0] else: @@ -1068,19 +1316,6 @@ def export_data_function(export_option, selected_conversations, selected_notes): ) - - -def update_conversation_title(conversation_id, new_title): - """Update the title of a conversation.""" - try: - query = "UPDATE conversation_metadata SET title = ? WHERE conversation_id = ?" - execute_query(query, (new_title, conversation_id)) - logging.info(f"Conversation '{conversation_id}' title updated to '{new_title}'") - except Exception as e: - logging.error(f"Error updating conversation title: {e}") - raise - - def convert_file_to_text(file_path): """Convert various file types to plain text.""" file_extension = os.path.splitext(file_path)[1].lower() diff --git a/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py b/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py index f0181feda..ca9f33170 100644 --- a/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py +++ b/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py @@ -10,16 +10,13 @@ # # Local Imports from App_Function_Libraries.Chunk_Lib import improved_chunking_process -from App_Function_Libraries.DB.DB_Manager import update_media_content, load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import update_media_content, list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import fetch_item_details, fetch_items_by_keyword, \ fetch_items_by_content, fetch_items_by_title_or_url from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_chunk from App_Function_Libraries.Utils.Utils import load_comprehensive_config, default_api_endpoint, global_api_endpoints, \ format_api_name - - -# # ###################################################################################################################### # @@ -36,6 +33,10 @@ def create_resummary_tab(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + + # Get initial prompts for first page + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + with gr.TabItem("Re-Summarize", visible=True): gr.Markdown("# Re-Summarize Existing Content") with gr.Row(): @@ -48,7 +49,6 @@ def create_resummary_tab(): item_mapping = gr.State({}) with gr.Row(): - # Refactored API selection dropdown api_name_input = gr.Dropdown( choices=["None"] + [format_api_name(api) for api in global_api_endpoints], value=default_value, @@ -70,9 +70,17 @@ def create_resummary_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + + # Add pagination controls for preset prompts + with gr.Row(visible=False) as preset_prompt_controls: + prev_page = gr.Button("Previous") + current_page_text = gr.Text(f"Page {current_page} of {total_pages}") + next_page = gr.Button("Next") + current_page_state = gr.State(value=1) + with gr.Row(): preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=initial_prompts, visible=False) with gr.Row(): custom_prompt_input = gr.Textbox(label="Custom Prompt", @@ -101,6 +109,15 @@ def create_resummary_tab(): lines=3, visible=False) + def update_prompt_page(direction, current_page_val): + new_page = max(1, min(total_pages, current_page_val + direction)) + prompts, 
_, _ = list_prompts(page=new_page, per_page=20) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -108,6 +125,19 @@ def update_prompts(preset_name): gr.update(value=prompts["system_prompt"], visible=True) ) + # Connect pagination buttons + prev_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_page_state], + outputs=[preset_prompt, current_page_text, current_page_state] + ) + + next_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_page_state], + outputs=[preset_prompt, current_page_text, current_page_state] + ) + preset_prompt.change( update_prompts, inputs=preset_prompt, @@ -124,9 +154,9 @@ def update_prompts(preset_name): outputs=[custom_prompt_input, system_prompt_input] ) preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, preset_prompt_controls] ) # Connect the UI elements @@ -155,7 +185,12 @@ def update_prompts(preset_name): outputs=result_output ) - return search_query_input, search_type_input, search_button, items_output, item_mapping, api_name_input, api_key_input, chunking_options_checkbox, chunking_options_box, chunk_method, max_chunk_size, chunk_overlap, custom_prompt_checkbox, custom_prompt_input, resummarize_button, result_output + return ( + search_query_input, search_type_input, search_button, items_output, + item_mapping, api_name_input, api_key_input, chunking_options_checkbox, + chunking_options_box, chunk_method, max_chunk_size, chunk_overlap, + custom_prompt_checkbox, custom_prompt_input, resummarize_button, result_output + ) def update_resummarize_dropdown(search_query, search_type): diff --git a/App_Function_Libraries/Gradio_UI/Search_Tab.py b/App_Function_Libraries/Gradio_UI/Search_Tab.py index 2ad075b81..64c9527b5 100644 --- a/App_Function_Libraries/Gradio_UI/Search_Tab.py +++ b/App_Function_Libraries/Gradio_UI/Search_Tab.py @@ -11,8 +11,8 @@ # # Local Imports from App_Function_Libraries.DB.DB_Manager import view_database, search_and_display_items, get_all_document_versions, \ - fetch_item_details_single, fetch_paginated_data, fetch_item_details, get_latest_transcription -from App_Function_Libraries.DB.SQLite_DB import search_prompts, get_document_version + fetch_item_details_single, fetch_paginated_data, fetch_item_details, get_latest_transcription, search_prompts, \ + get_document_version from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown, update_detailed_view from App_Function_Libraries.Utils.Utils import get_database_path, format_text_with_line_breaks # @@ -80,8 +80,8 @@ def format_as_html(content, title): """ def create_search_tab(): - with gr.TabItem("Search / Detailed View", visible=True): - gr.Markdown("# Search across all ingested items in the Database") + with gr.TabItem("Media DB Search / Detailed View", visible=True): + gr.Markdown("# Search across all ingested items in the Media Database") with gr.Row(): with gr.Column(scale=1): gr.Markdown("by Title / URL / Keyword / or Content via SQLite Full-Text-Search") @@ -150,8 +150,8 @@ def display_search_results(query): def create_search_summaries_tab(): - with gr.TabItem("Search/View Title+Summary", visible=True): - gr.Markdown("# Search across all ingested items in the Database and review their summaries") + with gr.TabItem("Media DB 
Search/View Title+Summary", visible=True): + gr.Markdown("# Search across all ingested items in the Media Database and review their summaries") gr.Markdown("Search by Title / URL / Keyword / or Content via SQLite Full-Text-Search") with gr.Row(): with gr.Column(): diff --git a/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py b/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py index 20c5bc90d..3d03af228 100644 --- a/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py +++ b/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py @@ -10,10 +10,12 @@ # External Imports import gradio as gr import yt_dlp + +from App_Function_Libraries.Chunk_Lib import improved_chunking_process # # Local Imports -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts, add_media_to_database, \ - check_media_and_whisper_model, check_existing_media, update_media_content_with_version +from App_Function_Libraries.DB.DB_Manager import add_media_to_database, \ + check_media_and_whisper_model, check_existing_media, update_media_content_with_version, list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models, update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import error_handler from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_transcription, perform_summarization, \ @@ -65,15 +67,20 @@ def create_video_transcription_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): + # Add pagination controls preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=[], visible=False) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) with gr.Row(): system_prompt_input = gr.Textbox(label="System Prompt", value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] @@ -96,22 +103,75 @@ def create_video_transcription_tab(): lines=3, visible=False, interactive=True) + with gr.Row(): + custom_prompt_input = gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) + custom_prompt_checkbox.change( - fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), + fn=lambda x: (gr.update(visible=x, interactive=x), gr.update(visible=x, interactive=x)), inputs=[custom_prompt_checkbox], outputs=[custom_prompt_input, system_prompt_input] ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( - gr.update(value=prompts["user_prompt"], visible=True), - gr.update(value=prompts["system_prompt"], visible=True) + gr.update(value=prompts["user_prompt"], visible=True, interactive=True), + gr.update(value=prompts["system_prompt"], visible=True, interactive=True) ) preset_prompt.change( @@ -209,7 +269,6 @@ def process_videos_with_error_handling(inputs, start_time, end_time, diarize, va try: # Start overall processing timer proc_start_time = datetime.now() - # FIXME - summarize_recursively is not being used... 
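# --- Editor's example (not part of the patch) ---------------------------------
# The preset-prompt pagination wiring added above is repeated in the
# Re-Summarize, Video Transcription, and Website Scraping tabs. A minimal,
# self-contained sketch of the same pattern follows; it assumes only that
# list_prompts(page, per_page) returns (choices, total_pages, current_page),
# and the stub below stands in for the real DB-backed helper.
import gradio as gr

def list_prompts(page=1, per_page=20):
    all_prompts = [f"Prompt {i}" for i in range(1, 101)]  # stand-in data
    total_pages = max(1, (len(all_prompts) + per_page - 1) // per_page)
    page = max(1, min(page, total_pages))
    start = (page - 1) * per_page
    return all_prompts[start:start + per_page], total_pages, page

def turn_page(direction, page):
    choices, total_pages, page = list_prompts(page + direction)
    return gr.update(choices=choices), f"Page {page} of {total_pages}", page

with gr.Blocks() as demo:
    page_state = gr.State(1)
    first_choices, first_total, _ = list_prompts()
    preset = gr.Dropdown(choices=first_choices, label="Select Preset Prompt")
    page_label = gr.Markdown(f"Page 1 of {first_total}")
    prev_btn = gr.Button("Previous Page")
    next_btn = gr.Button("Next Page")
    prev_btn.click(lambda p: turn_page(-1, p), inputs=page_state,
                   outputs=[preset, page_label, page_state])
    next_btn.click(lambda p: turn_page(1, p), inputs=page_state,
                   outputs=[preset, page_label, page_state])
# ------------------------------------------------------------------------------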
logging.info("Entering process_videos_with_error_handling") logging.info(f"Received inputs: {inputs}") @@ -261,7 +320,6 @@ def process_videos_with_error_handling(inputs, start_time, end_time, diarize, va all_summaries = "" # Start timing - # FIXME - utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC). start_proc = datetime.now() for i in range(0, len(all_inputs), batch_size): @@ -323,7 +381,7 @@ def process_videos_with_error_handling(inputs, start_time, end_time, diarize, va input_item, 2, whisper_model, custom_prompt, start_seconds, api_name, api_key, - vad_use, False, False, False, 0.01, None, keywords, None, diarize, + vad_use, False, False, summarize_recursively, 0.01, None, keywords, None, diarize, end_time=end_seconds, include_timestamps=timestamp_option, metadata=video_metadata, @@ -782,7 +840,54 @@ def process_url_with_metadata(input_item, num_speakers, whisper_model, custom_pr # API key resolution handled at base of function if none provided api_key = api_key if api_key else None logging.info(f"process_url_with_metadata: Starting summarization with {api_name}...") - summary_text = perform_summarization(api_name, full_text_with_metadata, custom_prompt, api_key) + + # Perform Chunking if enabled + # FIXME - Setup a proper prompt for Recursive Summarization + if use_chunking: + logging.info("process_url_with_metadata: Chunking enabled. Starting chunking...") + chunked_texts = improved_chunking_process(full_text_with_metadata, chunk_options) + + if chunked_texts is None: + logging.warning("Chunking failed, falling back to full text summarization") + summary_text = perform_summarization(api_name, full_text_with_metadata, custom_prompt, + api_key) + else: + logging.debug( + f"process_url_with_metadata: Chunking completed. 
Processing {len(chunked_texts)} chunks...") + summaries = [] + + if rolling_summarization: + # Perform recursive summarization on each chunk + for chunk in chunked_texts: + chunk_summary = perform_summarization(api_name, chunk['text'], custom_prompt, + api_key) + if chunk_summary: + summaries.append( + f"Chunk {chunk['metadata']['chunk_index']}/{chunk['metadata']['total_chunks']}: {chunk_summary}") + summary_text = "\n\n".join(summaries) + else: + logging.error("All chunk summarizations failed") + summary_text = None + + for chunk in chunked_texts: + # Perform Non-recursive summarization on each chunk + chunk_summary = perform_summarization(api_name, chunk['text'], custom_prompt, + api_key) + if chunk_summary: + summaries.append( + f"Chunk {chunk['metadata']['chunk_index']}/{chunk['metadata']['total_chunks']}: {chunk_summary}") + + if summaries: + summary_text = "\n\n".join(summaries) + logging.info(f"Successfully summarized {len(summaries)} chunks") + else: + logging.error("All chunk summarizations failed") + summary_text = None + else: + # Regular summarization without chunking + summary_text = perform_summarization(api_name, full_text_with_metadata, custom_prompt, + api_key) if api_name else None + if summary_text is None: logging.error("Summarization failed.") return None, None, None, None, None, None diff --git a/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py b/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py index 85ce9dfe6..021c492ee 100644 --- a/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py +++ b/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py @@ -3,14 +3,15 @@ # # Imports import html +import logging + # # External Imports import gradio as gr # # Local Imports from App_Function_Libraries.DB.DB_Manager import view_database, get_all_document_versions, \ - fetch_paginated_data, fetch_item_details, get_latest_transcription, list_prompts, fetch_prompt_details, \ - load_preset_prompts + fetch_paginated_data, fetch_item_details, get_latest_transcription, list_prompts, fetch_prompt_details from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_keywords_for_note, search_conversations_by_keywords, \ get_notes_by_keywords, get_keywords_for_conversation, get_db_connection, get_all_conversations, load_chat_history, \ get_notes @@ -22,117 +23,6 @@ # # Functions -def create_prompt_view_tab(): - with gr.TabItem("View Prompt Database", visible=True): - gr.Markdown("# View Prompt Database Entries") - with gr.Row(): - with gr.Column(): - entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) - page_number = gr.Number(value=1, label="Page Number", precision=0) - view_button = gr.Button("View Page") - next_page_button = gr.Button("Next Page") - previous_page_button = gr.Button("Previous Page") - pagination_info = gr.Textbox(label="Pagination Info", interactive=False) - prompt_selector = gr.Dropdown(label="Select Prompt to View", choices=[]) - with gr.Column(): - results_table = gr.HTML() - selected_prompt_display = gr.HTML() - - def view_database(page, entries_per_page): - try: - prompts, total_pages, current_page = list_prompts(page, entries_per_page) - - table_html = "" - table_html += "" - prompt_choices = [] - for prompt_name in prompts: - details = fetch_prompt_details(prompt_name) - if details: - title, _, _, _, _, _ = details - author = "Unknown" # Assuming author is not stored in the current schema - table_html += f"" - prompt_choices.append((title, title)) # Using title as both label and value - table_html += "
</table>" - - total_prompts = len(load_preset_prompts()) # This might be inefficient for large datasets - pagination = f"Page {current_page} of {total_pages} (Total prompts: {total_prompts})" - - return table_html, pagination, total_pages, prompt_choices - except Exception as e: - return f"<p>Error fetching prompts: {e}</p>", "Error", 0, [] - - def update_page(page, entries_per_page): - results, pagination, total_pages, prompt_choices = view_database(page, entries_per_page) - next_disabled = page >= total_pages - prev_disabled = page <= 1 - return results, pagination, page, gr.update(interactive=not next_disabled), gr.update( - interactive=not prev_disabled), gr.update(choices=prompt_choices) - - def go_to_next_page(current_page, entries_per_page): - next_page = current_page + 1 - return update_page(next_page, entries_per_page) - - def go_to_previous_page(current_page, entries_per_page): - previous_page = max(1, current_page - 1) - return update_page(previous_page, entries_per_page) - - def display_selected_prompt(prompt_name): - details = fetch_prompt_details(prompt_name) - if details: - title, author, description, system_prompt, user_prompt, keywords = details - # Handle None values by converting them to empty strings - description = description or "" - system_prompt = system_prompt or "" - user_prompt = user_prompt or "" - author = author or "Unknown" - keywords = keywords or "" - - html_content = f""" - <div> - <h3>{html.escape(title)}</h3> <h4>by {html.escape(author)}</h4> - <p><strong>Description:</strong> {html.escape(description)}</p> - <div> - <strong>System Prompt:</strong> - <pre>{html.escape(system_prompt)}</pre> - </div> - <div> - <strong>User Prompt:</strong> - <pre>{html.escape(user_prompt)}</pre> - </div> - <p><strong>Keywords:</strong> {html.escape(keywords)}</p> - </div> - """ - return html_content - else: - return "<p>Prompt not found.</p>" - - view_button.click( - fn=update_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - next_page_button.click( - fn=go_to_next_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - previous_page_button.click( - fn=go_to_previous_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - prompt_selector.change( - fn=display_selected_prompt, - inputs=[prompt_selector], - outputs=[selected_prompt_display] - ) - def format_as_html(content, title): escaped_content = html.escape(content) formatted_content = escaped_content.replace('\n', '<br>
') @@ -491,18 +381,24 @@ def format_conversations_table(conversations): Title Keywords Notes + Rating """ - for conv_id, title in conversations: + for conversation in conversations: + conv_id = conversation['conversation_id'] + title = conversation['title'] + rating = conversation.get('rating', '') # Use get() to handle cases where rating might not exist + keywords = get_keywords_for_conversation(conv_id) notes = get_notes(conv_id) table_html += f""" - {html.escape(title)} + {html.escape(str(title))} {html.escape(', '.join(keywords))} {len(notes)} note(s) + {html.escape(str(rating))} """ table_html += "" @@ -586,8 +482,12 @@ def update_page(page, entries_per_page): conversations, total_pages, total_count = get_all_conversations(page, entries_per_page) pagination = f"Page {page} of {total_pages} (Total conversations: {total_count})" - choices = [f"{title} (ID: {conv_id})" for conv_id, title in conversations] - new_item_mapping = {f"{title} (ID: {conv_id})": conv_id for conv_id, title in conversations} + # Handle the dictionary structure correctly + choices = [f"{conv['title']} (ID: {conv['conversation_id']})" for conv in conversations] + new_item_mapping = { + f"{conv['title']} (ID: {conv['conversation_id']})": conv['conversation_id'] + for conv in conversations + } next_disabled = page >= total_pages prev_disabled = page <= 1 @@ -605,6 +505,7 @@ def update_page(page, entries_per_page): new_item_mapping ) except Exception as e: + logging.error(f"Error in update_page: {str(e)}", exc_info=True) return ( gr.update(choices=[], value=None), f"Error: {str(e)}", @@ -674,8 +575,18 @@ def display_conversation_details(selected_item, item_mapping): view_button.click( fn=update_page, inputs=[page_number, entries_per_page], - outputs=[items_output, pagination_info, page_number, next_page_button, previous_page_button, - conversation_title, keywords_output, chat_history_output, notes_output, item_mapping] + outputs=[ + items_output, + pagination_info, + page_number, + next_page_button, + previous_page_button, + conversation_title, + keywords_output, + chat_history_output, + notes_output, + item_mapping + ] ) next_page_button.click( @@ -792,7 +703,7 @@ def format_notes_html(notes_data): return html_content def view_items(keywords, page, entries_per_page): - if not keywords: + if not keywords or (isinstance(keywords, list) and len(keywords) == 0): return ( "
<p>Please select at least one keyword.</p>", "<p>Please select at least one keyword.</p>
", @@ -802,14 +713,17 @@ def view_items(keywords, page, entries_per_page): ) try: + # Ensure keywords is a list + keywords_list = keywords if isinstance(keywords, list) else [keywords] + # Get conversations for selected keywords conversations, conv_total_pages, conv_count = search_conversations_by_keywords( - keywords, page, entries_per_page + keywords_list, page, entries_per_page ) # Get notes for selected keywords notes, notes_total_pages, notes_count = get_notes_by_keywords( - keywords, page, entries_per_page + keywords_list, page, entries_per_page ) # Format results as HTML @@ -833,6 +747,7 @@ def view_items(keywords, page, entries_per_page): gr.update(interactive=not prev_disabled) ) except Exception as e: + logging.error(f"Error in view_items: {str(e)}") return ( f"
<p>Error: {str(e)}</p>", f"<p>Error: {str(e)}</p>
", diff --git a/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py b/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py index 80b19ba9f..5204547e0 100644 --- a/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py +++ b/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py @@ -22,7 +22,7 @@ # Local Imports from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_from_sitemap, scrape_by_url_level, \ scrape_article, collect_bookmarks, scrape_and_summarize_multiple, collect_urls_from_file -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize @@ -314,12 +314,22 @@ def create_website_scraping_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) with gr.Row(): temp_slider = gr.Slider(0.1, 2.0, 0.7, label="Temperature") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) with gr.Row(): + # Add pagination controls preset_prompt = gr.Dropdown( label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=[], visible=False ) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + with gr.Row(): website_custom_prompt_input = gr.Textbox( label="Custom Prompt", @@ -421,10 +431,57 @@ def update_ui_for_scrape_method(method): inputs=[custom_prompt_checkbox], outputs=[website_custom_prompt_input, system_prompt_input] ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + 
page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) def update_prompts(preset_name): diff --git a/App_Function_Libraries/Gradio_UI/Chat_Workflows.py b/App_Function_Libraries/Gradio_UI/Workflows_tab.py similarity index 96% rename from App_Function_Libraries/Gradio_UI/Chat_Workflows.py rename to App_Function_Libraries/Gradio_UI/Workflows_tab.py index 802df6de9..5c911d290 100644 --- a/App_Function_Libraries/Gradio_UI/Chat_Workflows.py +++ b/App_Function_Libraries/Gradio_UI/Workflows_tab.py @@ -1,5 +1,5 @@ # Chat_Workflows.py -# Description: UI for Chat Workflows +# Description: Gradio UI for Chat Workflows # # Imports import json @@ -9,11 +9,11 @@ # External Imports import gradio as gr # +# Local Imports from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper, search_conversations, \ load_conversation -from App_Function_Libraries.Chat import save_chat_history_to_db_wrapper +from App_Function_Libraries.Chat.Chat_Functions import save_chat_history_to_db_wrapper from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name - # ############################################################################################################ # @@ -74,6 +74,7 @@ def chat_workflows_tab(): clear_btn = gr.Button("Clear Chat") chat_media_name = gr.Textbox(label="Custom Chat Name(optional)") save_btn = gr.Button("Save Chat to Database") + save_status = gr.Textbox(label="Save Status", interactive=False) def update_workflow_ui(workflow_name): if not workflow_name: @@ -164,7 +165,7 @@ def process_workflow_step(message, history, context, workflow_name, api_endpoint save_btn.click( save_chat_history_to_db_wrapper, inputs=[chatbot, conversation_id, media_content, chat_media_name], - outputs=[conversation_id, gr.Textbox(label="Save Status")] + outputs=[conversation_id, save_status] ) search_conversations_btn.click( diff --git a/App_Function_Libraries/Gradio_UI/Writing_tab.py b/App_Function_Libraries/Gradio_UI/Writing_tab.py index 37f5a8fdd..306513314 100644 --- a/App_Function_Libraries/Gradio_UI/Writing_tab.py +++ b/App_Function_Libraries/Gradio_UI/Writing_tab.py @@ -46,17 +46,17 @@ def grammar_style_check(input_text, custom_prompt, api_name, api_key, system_pro def create_grammar_style_check_tab(): - try: - default_value = None - if default_api_endpoint: - if default_api_endpoint in global_api_endpoints: - default_value = format_api_name(default_api_endpoint) - else: - logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") - except Exception as e: - logging.error(f"Error setting default API endpoint: {str(e)}") - default_value = None with gr.TabItem("Grammar and Style Check", visible=True): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.Row(): with gr.Column(): gr.Markdown("# Grammar and Style Check") @@ -317,63 +317,63 @@ def create_document_feedback_tab(): with gr.Row(): compare_button = gr.Button("Compare Feedback") - 
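# --- Editor's example (not part of the patch) ---------------------------------
# The save_status change above fixes a common Gradio pitfall: a component
# instantiated inline in an event's outputs list is never rendered, so writes
# to it are silently dropped. A sketch of the broken vs. corrected wiring
# (save_chat is a hypothetical handler):
import gradio as gr

def save_chat():
    return "Chat saved."

with gr.Blocks() as demo:
    save_btn = gr.Button("Save Chat to Database")
    save_status = gr.Textbox(label="Save Status", interactive=False)  # declared in the layout
    # Broken: outputs=[gr.Textbox(label="Save Status")] creates an unrendered component.
    save_btn.click(save_chat, outputs=[save_status])  # Fixed: reference the rendered one.
# ------------------------------------------------------------------------------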
feedback_history = gr.State([]) - - def add_custom_persona(name, description): - updated_choices = persona_dropdown.choices + [name] - persona_prompts[name] = f"As {name}, {description}, provide feedback on the following text:" - return gr.update(choices=updated_choices) - - def update_feedback_history(current_text, persona, feedback): - # Ensure feedback_history.value is initialized and is a list - if feedback_history.value is None: - feedback_history.value = [] - - history = feedback_history.value - - # Append the new entry to the history - history.append({"text": current_text, "persona": persona, "feedback": feedback}) - - # Keep only the last 5 entries in the history - feedback_history.value = history[-10:] - - # Generate and return the updated HTML - return generate_feedback_history_html(feedback_history.value) - - def compare_feedback(text, selected_personas, api_name, api_key): - results = [] - for persona in selected_personas: - feedback = generate_writing_feedback(text, persona, "Overall", api_name, api_key) - results.append(f"### {persona}'s Feedback:\n{feedback}\n\n") - return "\n".join(results) - - add_custom_persona_button.click( - fn=add_custom_persona, - inputs=[custom_persona_name, custom_persona_description], - outputs=persona_dropdown - ) - - get_feedback_button.click( - fn=lambda text, persona, aspect, api_name, api_key: ( - generate_writing_feedback(text, persona, aspect, api_name, api_key), - calculate_readability(text), - update_feedback_history(text, persona, generate_writing_feedback(text, persona, aspect, api_name, api_key)) - ), - inputs=[input_text, persona_dropdown, aspect_dropdown, api_name_input, api_key_input], - outputs=[feedback_output, readability_output, feedback_history_display] - ) - - compare_button.click( - fn=compare_feedback, - inputs=[input_text, compare_personas, api_name_input, api_key_input], - outputs=feedback_output - ) - - generate_prompt_button.click( - fn=generate_writing_prompt, - inputs=[persona_dropdown, api_name_input, api_key_input], - outputs=input_text - ) + feedback_history = gr.State([]) + + def add_custom_persona(name, description): + updated_choices = persona_dropdown.choices + [name] + persona_prompts[name] = f"As {name}, {description}, provide feedback on the following text:" + return gr.update(choices=updated_choices) + + def update_feedback_history(current_text, persona, feedback): + # Ensure feedback_history.value is initialized and is a list + if feedback_history.value is None: + feedback_history.value = [] + + history = feedback_history.value + + # Append the new entry to the history + history.append({"text": current_text, "persona": persona, "feedback": feedback}) + + # Keep only the last 5 entries in the history + feedback_history.value = history[-10:] + + # Generate and return the updated HTML + return generate_feedback_history_html(feedback_history.value) + + def compare_feedback(text, selected_personas, api_name, api_key): + results = [] + for persona in selected_personas: + feedback = generate_writing_feedback(text, persona, "Overall", api_name, api_key) + results.append(f"### {persona}'s Feedback:\n{feedback}\n\n") + return "\n".join(results) + + add_custom_persona_button.click( + fn=add_custom_persona, + inputs=[custom_persona_name, custom_persona_description], + outputs=persona_dropdown + ) + + get_feedback_button.click( + fn=lambda text, persona, aspect, api_name, api_key: ( + generate_writing_feedback(text, persona, aspect, api_name, api_key), + calculate_readability(text), + update_feedback_history(text, 
persona, generate_writing_feedback(text, persona, aspect, api_name, api_key)) + ), + inputs=[input_text, persona_dropdown, aspect_dropdown, api_name_input, api_key_input], + outputs=[feedback_output, readability_output, feedback_history_display] + ) + + compare_button.click( + fn=compare_feedback, + inputs=[input_text, compare_personas, api_name_input, api_key_input], + outputs=feedback_output + ) + + generate_prompt_button.click( + fn=generate_writing_prompt, + inputs=[persona_dropdown, api_name_input, api_key_input], + outputs=input_text + ) return input_text, feedback_output, readability_output, feedback_history_display diff --git a/App_Function_Libraries/Plaintext/Plaintext_Files.py b/App_Function_Libraries/Plaintext/Plaintext_Files.py index f5038a967..ba3c686db 100644 --- a/App_Function_Libraries/Plaintext/Plaintext_Files.py +++ b/App_Function_Libraries/Plaintext/Plaintext_Files.py @@ -2,9 +2,12 @@ # Description: This file contains functions for reading and writing plaintext files. # # Import necessary libraries +import logging import os import tempfile import zipfile +from datetime import datetime + # # External Imports from docx2txt import docx2txt @@ -12,57 +15,161 @@ # # Local Imports from App_Function_Libraries.Gradio_UI.Import_Functionality import import_data +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram + + # ####################################################################################################################### # # Function Definitions -def import_plain_text_file(file_path, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, - api_key): +def import_plain_text_file(file_path, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Import a single plain text file.""" try: + log_counter("file_processing_attempt", labels={"file_path": file_path}) + + # Extract title from filename + title = os.path.splitext(os.path.basename(file_path))[0] + # Determine the file type and convert if necessary file_extension = os.path.splitext(file_path)[1].lower() - if file_extension == '.rtf': - with tempfile.NamedTemporaryFile(suffix='.md', delete=False) as temp_file: - convert_file(file_path, 'md', outputfile=temp_file.name) - file_path = temp_file.name - elif file_extension == '.docx': - content = docx2txt.process(file_path) - else: - with open(file_path, 'r', encoding='utf-8') as file: - content = file.read() - - # Process the content - return import_data(content, title, author, keywords, system_prompt, - user_prompt, auto_summarize, api_name, api_key) + + # Get the content based on file type + try: + if file_extension == '.rtf': + with tempfile.NamedTemporaryFile(suffix='.md', delete=False) as temp_file: + convert_file(file_path, 'md', outputfile=temp_file.name) + file_path = temp_file.name + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + log_counter("rtf_conversion_success", labels={"file_path": file_path}) + elif file_extension == '.docx': + content = docx2txt.process(file_path) + log_counter("docx_conversion_success", labels={"file_path": file_path}) + else: + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + except Exception as e: + logging.error(f"Error reading file content: {str(e)}") + return f"Error reading file content: {str(e)}" + + # Import the content + result = import_data( + content, # Pass the content directly + title, + author, + keywords, + user_prompt, # This is the custom_prompt parameter + None, # No 
summary - let auto_summarize handle it + auto_summarize, + api_name, + api_key + ) + + log_counter("file_processing_success", labels={"file_path": file_path}) + return result + except Exception as e: - return f"Error processing file: {str(e)}" + logging.exception(f"Error processing file {file_path}") + log_counter("file_processing_error", labels={"file_path": file_path, "error": str(e)}) + return f"Error processing file {os.path.basename(file_path)}: {str(e)}" + -def process_plain_text_zip_file(zip_file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): +def process_plain_text_zip_file(zip_file, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Process multiple text files from a zip archive.""" results = [] - with tempfile.TemporaryDirectory() as temp_dir: - with zipfile.ZipFile(zip_file.name, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - for filename in os.listdir(temp_dir): - if filename.lower().endswith(('.md', '.txt', '.rtf', '.docx')): - file_path = os.path.join(temp_dir, filename) - result = import_plain_text_file(file_path, title, author, keywords, system_prompt, - user_prompt, auto_summarize, api_name, api_key) - results.append(f"File: {filename} - {result}") - - return "\n".join(results) - - -def import_file_handler(file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): - if file.name.lower().endswith(('.md', '.txt', '.rtf', '.docx')): - return import_plain_text_file(file.name, title, author, keywords, system_prompt, user_prompt, auto_summarize, - api_name, api_key) - elif file.name.lower().endswith('.zip'): - return process_plain_text_zip_file(file, title, author, keywords, system_prompt, user_prompt, auto_summarize, - api_name, api_key) - else: - return "Unsupported file type. Please upload a .md, .txt, .rtf, .docx file or a .zip file containing these file types." + try: + with tempfile.TemporaryDirectory() as temp_dir: + with zipfile.ZipFile(zip_file.name, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + for filename in os.listdir(temp_dir): + if filename.lower().endswith(('.md', '.txt', '.rtf', '.docx')): + file_path = os.path.join(temp_dir, filename) + result = import_plain_text_file( + file_path=file_path, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + results.append(f"📄 {filename}: {result}") + + return "\n\n".join(results) + except Exception as e: + logging.exception(f"Error processing zip file: {str(e)}") + return f"Error processing zip file: {str(e)}" + + + +def import_file_handler(files, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Handle the import of one or more files, including zip files.""" + try: + if not files: + log_counter("plaintext_import_error", labels={"error": "No files uploaded"}) + return "No files uploaded." 
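# --- Editor's example (not part of the patch) ---------------------------------
# The per-file loop below pairs an attempt counter, a success/error counter,
# and a duration histogram around each import. The same pattern, condensed
# into a reusable helper; process_fn is a placeholder for any importer.
import logging
from datetime import datetime
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram

def process_with_metrics(file_name, process_fn):
    log_counter("plaintext_import_attempt", labels={"file_name": file_name})
    start_time = datetime.now()
    try:
        result = process_fn(file_name)
        log_counter("plaintext_import_success", labels={"file_name": file_name})
        return result
    except Exception as e:
        logging.exception(f"Import failed for {file_name}")
        log_counter("plaintext_import_error", labels={"error": str(e), "file_name": file_name})
        raise
    finally:
        duration = (datetime.now() - start_time).total_seconds()
        log_histogram("plaintext_import_duration", duration, labels={"file_name": file_name})
# ------------------------------------------------------------------------------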
+ + # Convert single file to list for consistent processing + if not isinstance(files, list): + files = [files] + + results = [] + for file in files: + log_counter("plaintext_import_attempt", labels={"file_name": file.name}) + + start_time = datetime.now() + + if not os.path.exists(file.name): + log_counter("plaintext_import_error", labels={"error": "File not found", "file_name": file.name}) + results.append(f"❌ File not found: {file.name}") + continue + + if file.name.lower().endswith(('.md', '.txt', '.rtf', '.docx')): + result = import_plain_text_file( + file_path=file.name, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + log_counter("plaintext_import_success", labels={"file_name": file.name}) + results.append(f"📄 {file.name}: {result}") + + elif file.name.lower().endswith('.zip'): + result = process_plain_text_zip_file( + zip_file=file, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + log_counter("zip_import_success", labels={"file_name": file.name}) + results.append(f"📦 {file.name}:\n{result}") + + else: + log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]}) + results.append(f"❌ Unsupported file type: {file.name}") + continue + + end_time = datetime.now() + processing_time = (end_time - start_time).total_seconds() + log_histogram("plaintext_import_duration", processing_time, labels={"file_name": file.name}) + + return "\n\n".join(results) + + except Exception as e: + logging.exception("Error in import_file_handler") + log_counter("plaintext_import_error", labels={"error": str(e)}) + return f"❌ Error during import: {str(e)}" # # End of Plaintext_Files.py diff --git a/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py b/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py index d192f55b8..c037eeadb 100644 --- a/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py +++ b/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py @@ -4,7 +4,7 @@ # Imports import re -from App_Function_Libraries.Chat import chat_api_call +from App_Function_Libraries.Chat.Chat_Functions import chat_api_call # # Local Imports # diff --git a/App_Function_Libraries/Prompt_Engineering/__Init__.py b/App_Function_Libraries/Prompt_Engineering/__Init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/App_Function_Libraries/RAG/ChromaDB_Library.py b/App_Function_Libraries/RAG/ChromaDB_Library.py index 3cec5fd8d..623e5adb4 100644 --- a/App_Function_Libraries/RAG/ChromaDB_Library.py +++ b/App_Function_Libraries/RAG/ChromaDB_Library.py @@ -49,36 +49,37 @@ # Function to preprocess and store all existing content in the database -def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"): - unprocessed_media = get_unprocessed_media(db=database) - total_media = len(unprocessed_media) - - for index, row in enumerate(unprocessed_media, 1): - media_id, content, media_type, file_name = row - collection_name = f"{media_type}_{media_id}" - - logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}") - - try: - process_and_store_content( - database=database, - content=content, - collection_name=collection_name, - media_id=media_id, - file_name=file_name or f"{media_type}_{media_id}", - create_embeddings=True, - 
create_contextualized=create_contextualized, - api_name=api_name - ) - - # Mark the media as processed in the database - mark_media_as_processed(database, media_id) - - logger.info(f"Successfully processed media ID {media_id}") - except Exception as e: - logger.error(f"Error processing media ID {media_id}: {str(e)}") - - logger.info("Finished preprocessing all unprocessed content") +# FIXME - Deprecated +# def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"): +# unprocessed_media = get_unprocessed_media(db=database) +# total_media = len(unprocessed_media) +# +# for index, row in enumerate(unprocessed_media, 1): +# media_id, content, media_type, file_name = row +# collection_name = f"{media_type}_{media_id}" +# +# logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}") +# +# try: +# process_and_store_content( +# database=database, +# content=content, +# collection_name=collection_name, +# media_id=media_id, +# file_name=file_name or f"{media_type}_{media_id}", +# create_embeddings=True, +# create_contextualized=create_contextualized, +# api_name=api_name +# ) +# +# # Mark the media as processed in the database +# mark_media_as_processed(database, media_id) +# +# logger.info(f"Successfully processed media ID {media_id}") +# except Exception as e: +# logger.error(f"Error processing media ID {media_id}: {str(e)}") +# +# logger.info("Finished preprocessing all unprocessed content") def batched(iterable, n): @@ -233,7 +234,10 @@ def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids logging.info(f"Number of embeddings: {len(embeddings)}, Dimension: {embedding_dim}") try: - # Attempt to get or create the collection + # Clean metadata + cleaned_metadatas = [clean_metadata(metadata) for metadata in metadatas] + + # Try to get or create the collection try: collection = chroma_client.get_collection(name=collection_name) logging.info(f"Existing collection '{collection_name}' found") @@ -258,7 +262,7 @@ def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids documents=texts, embeddings=embeddings, ids=ids, - metadatas=metadatas + metadatas=cleaned_metadatas ) logging.info(f"Successfully upserted {len(embeddings)} embeddings") @@ -290,12 +294,19 @@ def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[st # Fetch a sample of embeddings to check metadata sample_results = collection.get(limit=10, include=["metadatas"]) - if not sample_results['metadatas']: - raise ValueError("No metadata found in the collection") + if not sample_results.get('metadatas') or not any(sample_results['metadatas']): + logging.warning(f"No metadata found in the collection '{collection_name}'. 
Skipping this collection.") + return [] # Check if all embeddings use the same model and provider - embedding_models = [metadata.get('embedding_model') for metadata in sample_results['metadatas'] if metadata.get('embedding_model')] - embedding_providers = [metadata.get('embedding_provider') for metadata in sample_results['metadatas'] if metadata.get('embedding_provider')] + embedding_models = [ + metadata.get('embedding_model') for metadata in sample_results['metadatas'] + if metadata and metadata.get('embedding_model') + ] + embedding_providers = [ + metadata.get('embedding_provider') for metadata in sample_results['metadatas'] + if metadata and metadata.get('embedding_provider') + ] if not embedding_models or not embedding_providers: raise ValueError("Embedding model or provider information not found in metadata") @@ -319,13 +330,13 @@ def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[st ) if not results['documents'][0]: - logging.warning("No results found for the query") + logging.warning(f"No results found for the query in collection '{collection_name}'.") return [] return [{"content": doc, "metadata": meta} for doc, meta in zip(results['documents'][0], results['metadatas'][0])] except Exception as e: - logging.error(f"Error in vector_search: {str(e)}", exc_info=True) - raise + logging.error(f"Error in vector_search for collection '{collection_name}': {str(e)}", exc_info=True) + return [] def schedule_embedding(media_id: int, content: str, media_name: str): @@ -350,6 +361,21 @@ def schedule_embedding(media_id: int, content: str, media_name: str): logging.error(f"Error scheduling embedding for media_id {media_id}: {str(e)}") +def clean_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]: + """Clean metadata by removing None values and converting to appropriate types""" + cleaned = {} + for key, value in metadata.items(): + if value is not None: # Skip None values + if isinstance(value, (str, int, float, bool)): + cleaned[key] = value + elif isinstance(value, (np.int32, np.int64)): + cleaned[key] = int(value) + elif isinstance(value, (np.float32, np.float64)): + cleaned[key] = float(value) + else: + cleaned[key] = str(value) # Convert other types to string + return cleaned + # Function to process content, create chunks, embeddings, and store in ChromaDB and SQLite # def process_and_store_content(content: str, collection_name: str, media_id: int): # # Process the content into chunks diff --git a/App_Function_Libraries/RAG/Embeddings_Create.py b/App_Function_Libraries/RAG/Embeddings_Create.py index 0f57abbed..04ca6b09a 100644 --- a/App_Function_Libraries/RAG/Embeddings_Create.py +++ b/App_Function_Libraries/RAG/Embeddings_Create.py @@ -25,8 +25,6 @@ # # Functions: -# FIXME - Version 2 - # Load configuration loaded_config = load_comprehensive_config() embedding_provider = loaded_config['Embeddings']['embedding_provider'] @@ -331,177 +329,6 @@ def create_openai_embedding(text: str, model: str) -> List[float]: return embedding -# FIXME - Version 1 -# # FIXME - Add all globals to summarize.py -# loaded_config = load_comprehensive_config() -# embedding_provider = loaded_config['Embeddings']['embedding_provider'] -# embedding_model = loaded_config['Embeddings']['embedding_model'] -# embedding_api_url = loaded_config['Embeddings']['embedding_api_url'] -# embedding_api_key = loaded_config['Embeddings']['embedding_api_key'] -# -# # Embedding Chunking Settings -# chunk_size = loaded_config['Embeddings']['chunk_size'] -# overlap = loaded_config['Embeddings']['overlap'] -# 
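# --- Editor's example (not part of the patch) ---------------------------------
# The v1 block being removed here includes a RateLimiter class and an
# exponential_backoff decorator (visible below). A compact, dependency-free
# equivalent of that decorator pair, for reference:
import logging
import time
from functools import wraps
from threading import Lock

def rate_limited(max_calls, period):
    calls, lock = [], Lock()
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with lock:
                now = time.time()
                calls[:] = [t for t in calls if t > now - period]
                if len(calls) >= max_calls:
                    time.sleep(calls[0] + period - now)  # wait for the oldest call to age out
                calls.append(time.time())
            return func(*args, **kwargs)
        return wrapper
    return decorator

def with_backoff(max_retries=5, base_delay=1):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_retries - 1:
                        raise
                    delay = base_delay * (2 ** attempt)
                    logging.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}s.")
                    time.sleep(delay)
        return wrapper
    return decorator
# ------------------------------------------------------------------------------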
-# -# # FIXME - Add logging -# -# class HuggingFaceEmbedder: -# def __init__(self, model_name, timeout_seconds=120): # Default timeout of 2 minutes -# self.model_name = model_name -# self.tokenizer = None -# self.model = None -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# self.timeout_seconds = timeout_seconds -# self.last_used_time = 0 -# self.unload_timer = None -# -# def load_model(self): -# if self.model is None: -# self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) -# self.model = AutoModel.from_pretrained(self.model_name) -# self.model.to(self.device) -# self.last_used_time = time.time() -# self.reset_timer() -# -# def unload_model(self): -# if self.model is not None: -# del self.model -# del self.tokenizer -# if torch.cuda.is_available(): -# torch.cuda.empty_cache() -# self.model = None -# self.tokenizer = None -# if self.unload_timer: -# self.unload_timer.cancel() -# -# def reset_timer(self): -# if self.unload_timer: -# self.unload_timer.cancel() -# self.unload_timer = Timer(self.timeout_seconds, self.unload_model) -# self.unload_timer.start() -# -# def create_embeddings(self, texts): -# self.load_model() -# inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512) -# inputs = {k: v.to(self.device) for k, v in inputs.items()} -# with torch.no_grad(): -# outputs = self.model(**inputs) -# embeddings = outputs.last_hidden_state.mean(dim=1) -# return embeddings.cpu().numpy() -# -# # Global variable to hold the embedder -# huggingface_embedder = None -# -# -# class RateLimiter: -# def __init__(self, max_calls, period): -# self.max_calls = max_calls -# self.period = period -# self.calls = [] -# self.lock = Lock() -# -# def __call__(self, func): -# def wrapper(*args, **kwargs): -# with self.lock: -# now = time.time() -# self.calls = [call for call in self.calls if call > now - self.period] -# if len(self.calls) >= self.max_calls: -# sleep_time = self.calls[0] - (now - self.period) -# time.sleep(sleep_time) -# self.calls.append(time.time()) -# return func(*args, **kwargs) -# return wrapper -# -# -# def exponential_backoff(max_retries=5, base_delay=1): -# def decorator(func): -# @wraps(func) -# def wrapper(*args, **kwargs): -# for attempt in range(max_retries): -# try: -# return func(*args, **kwargs) -# except Exception as e: -# if attempt == max_retries - 1: -# raise -# delay = base_delay * (2 ** attempt) -# logging.warning(f"Attempt {attempt + 1} failed. Retrying in {delay} seconds. 
Error: {str(e)}") -# time.sleep(delay) -# return wrapper -# return decorator -# -# -# # FIXME - refactor/setup to use config file & perform chunking -# @exponential_backoff() -# @RateLimiter(max_calls=50, period=60) -# def create_embeddings_batch(texts: List[str], provider: str, model: str, api_url: str, timeout_seconds: int = 300) -> List[List[float]]: -# global embedding_models -# -# try: -# if provider.lower() == 'huggingface': -# if model not in embedding_models: -# if model == "dunzhang/stella_en_400M_v5": -# embedding_models[model] = ONNXEmbedder(model, model_dir, timeout_seconds) -# else: -# embedding_models[model] = HuggingFaceEmbedder(model, timeout_seconds) -# embedder = embedding_models[model] -# return embedder.create_embeddings(texts) -# -# elif provider.lower() == 'openai': -# logging.debug(f"Creating embeddings for {len(texts)} texts using OpenAI API") -# return [create_openai_embedding(text, model) for text in texts] -# -# elif provider.lower() == 'local': -# response = requests.post( -# api_url, -# json={"texts": texts, "model": model}, -# headers={"Authorization": f"Bearer {embedding_api_key}"} -# ) -# if response.status_code == 200: -# return response.json()['embeddings'] -# else: -# raise Exception(f"Error from local API: {response.text}") -# else: -# raise ValueError(f"Unsupported embedding provider: {provider}") -# except Exception as e: -# logging.error(f"Error in create_embeddings_batch: {str(e)}") -# raise -# -# def create_embedding(text: str, provider: str, model: str, api_url: str) -> List[float]: -# return create_embeddings_batch([text], provider, model, api_url)[0] -# -# -# def create_openai_embedding(text: str, model: str) -> List[float]: -# embedding = get_openai_embeddings(text, model) -# return embedding -# -# -# # FIXME - refactor to use onnx embeddings callout -# def create_stella_embeddings(text: str) -> List[float]: -# if embedding_provider == 'local': -# # Load the model and tokenizer -# tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_400M_v5") -# model = AutoModel.from_pretrained("dunzhang/stella_en_400M_v5") -# -# # Tokenize and encode the text -# inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) -# -# # Generate embeddings -# with torch.no_grad(): -# outputs = model(**inputs) -# -# # Use the mean of the last hidden state as the sentence embedding -# embeddings = outputs.last_hidden_state.mean(dim=1) -# -# return embeddings[0].tolist() # Convert to list for consistency -# elif embedding_provider == 'openai': -# return get_openai_embeddings(text, embedding_model) -# else: -# raise ValueError(f"Unsupported embedding provider: {embedding_provider}") -# # -# # End of F -# ############################################################## -# # # ############################################################## # # diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 10c8c5dfc..ad80eef1e 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -9,14 +9,16 @@ from typing import Dict, Any, List, Optional from App_Function_Libraries.DB.Character_Chat_DB import get_character_chats, perform_full_text_search_chat, \ - fetch_keywords_for_chats + fetch_keywords_for_chats, search_character_chat, search_character_cards, fetch_character_ids_by_keywords +from App_Function_Libraries.DB.RAG_QA_Chat_DB import search_rag_chat, search_rag_notes # # Local Imports from App_Function_Libraries.RAG.ChromaDB_Library import 
    process_and_store_content, vector_search, chroma_client
 from App_Function_Libraries.RAG.RAG_Persona_Chat import perform_vector_search_chat
 from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_custom_openai
 from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_article
-from App_Function_Libraries.DB.DB_Manager import search_db, fetch_keywords_for_media
+from App_Function_Libraries.DB.DB_Manager import fetch_keywords_for_media, search_media_db, get_notes_by_keywords, \
+    search_conversations_by_keywords
 from App_Function_Libraries.Utils.Utils import load_comprehensive_config
 from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
 #
@@ -40,6 +42,15 @@
 # Read the configuration file
 config.read('config.txt')
+
+search_functions = {
+    "Media DB": search_media_db,
+    "RAG Chat": search_rag_chat,
+    "RAG Notes": search_rag_notes,
+    "Character Chat": search_character_chat,
+    "Character Cards": search_character_cards
+}
+
 # RAG pipeline function for web scraping
 # def rag_web_scraping_pipeline(url: str, query: str, api_choice=None) -> Dict[str, Any]:
 #     try:
@@ -117,7 +128,20 @@
 # RAG Search with keyword filtering
 # FIXME - Update each called function to support modifiable top-k results
-def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top_k=10, apply_re_ranking=True) -> Dict[str, Any]:
+def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts_top_k=10, apply_re_ranking=True, database_types: Optional[List[str]] = None) -> Dict[str, Any]:
+    """
+    Run the full RAG pipeline: retrieve relevant content from the selected databases, then generate an answer.
+
+    Args:
+        query: Search query string
+        api_choice: API to use for generating the response
+        keywords: Optional comma-separated keyword string used to filter results
+        fts_top_k: Maximum number of results to return per search
+        apply_re_ranking: Whether to re-rank the combined results before building the context
+        database_types: List of database types to search ("Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards"); defaults to ["Media DB"]
+
+    Returns:
+        Dictionary containing the generated answer and the context used
+    """
+    if database_types is None:
+        database_types = ["Media DB"]
     log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
     start_time = time.time()
     try:
@@ -131,16 +155,97 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top
         keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
         logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}")

-        # Fetch relevant media IDs based on keywords if keywords are provided
-        relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None
-        logging.debug(f"\n\nenhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}")
+        relevant_ids = {}
+
+        # Fetch relevant IDs based on keywords if keywords are provided
+        if keyword_list:
+            try:
+                for db_type in database_types:
+                    if db_type == "Media DB":
+                        relevant_media_ids = fetch_relevant_media_ids(keyword_list)
+                        relevant_ids[db_type] = relevant_media_ids
+                        logging.debug(f"enhanced_rag_pipeline - {db_type} relevant media IDs: {relevant_media_ids}")
+
+                    elif db_type == "RAG Chat":
+                        conversations, total_pages, total_count = search_conversations_by_keywords(
+                            keywords=keyword_list)
+                        relevant_conversation_ids = [conv['conversation_id'] for conv in conversations]
+                        relevant_ids[db_type] = relevant_conversation_ids
+                        logging.debug(
+                            f"enhanced_rag_pipeline - {db_type} relevant conversation IDs: {relevant_conversation_ids}")

+                    elif db_type == "RAG Notes":
+                        notes, total_pages, total_count = get_notes_by_keywords(keyword_list)
+                        relevant_note_ids = [note_id for note_id, _, _, _ in notes]  #
Unpack note_id from the tuple + relevant_ids[db_type] = relevant_note_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant note IDs: {relevant_note_ids}") + + elif db_type == "Character Chat": + relevant_chat_ids = fetch_keywords_for_chats(keyword_list) + relevant_ids[db_type] = relevant_chat_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant chat IDs: {relevant_chat_ids}") + + elif db_type == "Character Cards": + # Assuming we have a function to fetch character IDs by keywords + relevant_character_ids = fetch_character_ids_by_keywords(keyword_list) + relevant_ids[db_type] = relevant_character_ids + logging.debug( + f"enhanced_rag_pipeline - {db_type} relevant character IDs: {relevant_character_ids}") + + else: + logging.error(f"Unsupported database type: {db_type}") + + except Exception as e: + logging.error(f"Error fetching relevant IDs: {str(e)}") + else: + relevant_ids = None + + # Extract relevant media IDs for each selected DB + # Prepare a dict to hold relevant_media_ids per DB + relevant_media_ids_dict = {} + if relevant_ids: + for db_type in database_types: + relevant_media_ids = relevant_ids.get(db_type, None) + if relevant_media_ids: + # Convert to List[str] if not None + relevant_media_ids_dict[db_type] = [str(media_id) for media_id in relevant_media_ids] + else: + relevant_media_ids_dict[db_type] = None + else: + relevant_media_ids_dict = {db_type: None for db_type in database_types} + + # Perform vector search for all selected databases + vector_results = [] + for db_type in database_types: + try: + db_relevant_ids = relevant_media_ids_dict.get(db_type) + results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k) + vector_results.extend(results) + logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}") + except Exception as e: + logging.error(f"Error performing vector search on {db_type}: {str(e)}") # Perform vector search + # FIXME vector_results = perform_vector_search(query, relevant_media_ids) logging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}") # Perform full-text search - fts_results = perform_full_text_search(query, relevant_media_ids) + #v1 + #fts_results = perform_full_text_search(query, database_type, relevant_media_ids, fts_top_k) + + # v2 + # Perform full-text search across specified databases + fts_results = [] + for db_type in database_types: + try: + db_relevant_ids = relevant_ids.get(db_type) if relevant_ids else None + db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k) + fts_results.extend(db_results) + logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}") + except Exception as e: + logging.error(f"Error performing full-text search on {db_type}: {str(e)}") + logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:") logging.debug( "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join( @@ -175,8 +280,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top # Update all_results based on reranking all_results = [all_results[result['id']] for result in reranked_results] - # Extract content from results (top 10 by default) - context = "\n".join([result['content'] for result in all_results[:top_k]]) + # Extract content from results (top fts_top_k by default) + context = "\n".join([result['content'] for result in all_results[:fts_top_k]]) logging.debug(f"Context length: {len(context)}") logging.debug(f"Context: {context[:200]}") @@ -208,6 +313,8 
@@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top
             "context": ""
         }

+
+# Need to write a test for this function FIXME
 def generate_answer(api_choice: str, context: str, query: str) -> str:
     # Metrics
@@ -336,6 +443,7 @@ def generate_answer(api_choice: str, context: str, query: str) -> str:
         logging.error(f"Error in generate_answer: {str(e)}")
         return "An error occurred while generating the answer."

+
 def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]:
     log_counter("perform_vector_search_attempt")
     start_time = time.time()
@@ -344,6 +452,8 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_
     try:
         for collection in all_collections:
             collection_results = vector_search(collection.name, query, k=top_k)
+            if not collection_results:
+                continue  # Skip empty results
             filtered_results = [
                 result for result in collection_results
                 if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
@@ -358,29 +468,75 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_
         logging.error(f"Error in perform_vector_search: {str(e)}")
         raise

-def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]:
-    log_counter("perform_full_text_search_attempt")
+
+# V2
+def perform_full_text_search(query: str, database_type: str, relevant_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]:
+    """
+    Perform full-text search on a specified database type.
+
+    Args:
+        query: Search query string
+        database_type: Type of database to search ("Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards")
+        relevant_ids: Optional list of IDs used to filter results (media, conversation, note, chat, or character IDs, depending on the database type)
+        fts_top_k: Maximum number of results to return
+
+    Returns:
+        List of search results with content and metadata
+    """
+    log_counter("perform_full_text_search_attempt", labels={"database_type": database_type})
     start_time = time.time()
+
     try:
-        fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10)
-        filtered_fts_results = [
-            {
-                "content": result['content'],
-                "metadata": {"media_id": result['id']}
-            }
-            for result in fts_results
-            if relevant_media_ids is None or result['id'] in relevant_media_ids
-        ]
+        # Set default for fts_top_k
+        if fts_top_k is None:
+            fts_top_k = 10
+
+        # Dispatch to the appropriate search function for the database type
+        if database_type not in search_functions:
+            raise ValueError(f"Unsupported database type: {database_type}")
+
+        results = search_functions[database_type](query, fts_top_k, relevant_ids)
+
         search_duration = time.time() - start_time
-        log_histogram("perform_full_text_search_duration", search_duration)
-        log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)})
-        return filtered_fts_results
+        log_histogram("perform_full_text_search_duration", search_duration,
+                      labels={"database_type": database_type})
+        log_counter("perform_full_text_search_success",
+                    labels={"database_type": database_type, "result_count": len(results)})
+
+        return results
+
     except Exception as e:
-        log_counter("perform_full_text_search_error", labels={"error": str(e)})
-        logging.error(f"Error in perform_full_text_search: {str(e)}")
+        log_counter("perform_full_text_search_error",
+                    labels={"database_type": database_type, "error": str(e)})
+        logging.error(f"Error in perform_full_text_search ({database_type}):
{str(e)}") raise +# v1 +# def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: +# log_counter("perform_full_text_search_attempt") +# start_time = time.time() +# try: +# fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) +# filtered_fts_results = [ +# { +# "content": result['content'], +# "metadata": {"media_id": result['id']} +# } +# for result in fts_results +# if relevant_media_ids is None or result['id'] in relevant_media_ids +# ] +# search_duration = time.time() - start_time +# log_histogram("perform_full_text_search_duration", search_duration) +# log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)}) +# return filtered_fts_results +# except Exception as e: +# log_counter("perform_full_text_search_error", labels={"error": str(e)}) +# logging.error(f"Error in perform_full_text_search: {str(e)}") +# raise + + def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]: log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)}) start_time = time.time() @@ -502,6 +658,7 @@ def enhanced_rag_pipeline_chat(query: str, api_choice: str, character_id: int, k logging.debug(f"enhanced_rag_pipeline_chat - Vector search results: {vector_results}") # Perform full-text search within the relevant chats + # FIXME - Update for DB Selection fts_results = perform_full_text_search_chat(query, relevant_chat_ids) logging.debug("enhanced_rag_pipeline_chat - Full-text search results:") logging.debug("\n".join([str(item) for item in fts_results])) diff --git a/App_Function_Libraries/RAG/RAG_QA_Chat.py b/App_Function_Libraries/RAG/RAG_QA_Chat.py index 5ec3af076..1c6580c06 100644 --- a/App_Function_Libraries/RAG/RAG_QA_Chat.py +++ b/App_Function_Libraries/RAG/RAG_QA_Chat.py @@ -12,7 +12,7 @@ from typing import List, Tuple, IO, Union # # Local Imports -from App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content +from App_Function_Libraries.DB.DB_Manager import db, search_media_db, DatabaseError, get_media_content from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer, enhanced_rag_pipeline from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # @@ -89,7 +89,7 @@ def search_database(query: str) -> List[Tuple[int, str]]: log_counter("search_database_attempt") start_time = time.time() # Implement database search functionality - results = search_db(query, ["title", "content"], "", page=1, results_per_page=10) + results = search_media_db(query, ["title", "content"], "", page=1, results_per_page=10) search_duration = time.time() - start_time log_histogram("search_database_duration", search_duration) log_counter("search_database_success", labels={"result_count": len(results)}) diff --git a/App_Function_Libraries/Utils/Utils.py b/App_Function_Libraries/Utils/Utils.py index 294571455..5e2e4f0db 100644 --- a/App_Function_Libraries/Utils/Utils.py +++ b/App_Function_Libraries/Utils/Utils.py @@ -95,8 +95,6 @@ def cleanup_downloads(): ####################################################################################################################### # Config loading # - - def load_comprehensive_config(): # Get the directory of the current script (Utils.py) current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -126,25 +124,33 @@ def load_comprehensive_config(): def get_project_root(): - # Get the directory of the current file (Utils.py) + """Get the 
absolute path to the project root directory.""" current_dir = os.path.dirname(os.path.abspath(__file__)) - # Go up two levels to reach the project root - # Assuming the structure is: project_root/App_Function_Libraries/Utils/Utils.py project_root = os.path.dirname(os.path.dirname(current_dir)) + logging.debug(f"Project root: {project_root}") return project_root + def get_database_dir(): - """Get the database directory (/tldw/Databases/).""" + """Get the absolute path to the database directory.""" db_dir = os.path.join(get_project_root(), 'Databases') + os.makedirs(db_dir, exist_ok=True) logging.debug(f"Database directory: {db_dir}") return db_dir -def get_database_path(db_name: Union[str, os.PathLike[AnyStr]]) -> str: - """Get the full path for a database file.""" - path = os.path.join(get_database_dir(), str(db_name)) - logging.debug(f"Database path for {db_name}: {path}") + +def get_database_path(db_name: str) -> str: + """ + Get the full absolute path for a database file. + Ensures the path is always within the Databases directory. + """ + # Remove any directory traversal attempts + safe_db_name = os.path.basename(db_name) + path = os.path.join(get_database_dir(), safe_db_name) + logging.debug(f"Database path for {safe_db_name}: {path}") return path + def get_project_relative_path(relative_path: Union[str, os.PathLike[AnyStr]]) -> str: """Convert a relative path to a path relative to the project root.""" path = os.path.join(get_project_root(), str(relative_path)) @@ -280,6 +286,10 @@ def load_and_log_configs(): # Prompts - FIXME prompt_path = config.get('Prompts', 'prompt_path', fallback='Databases/prompts.db') + # Auto-Save Values + save_character_chats = config.get('Auto-Save', 'save_character_chats', fallback='False') + save_rag_chats = config.get('Auto-Save', 'save_rag_chats', fallback='False') + return { 'api_keys': { 'anthropic': anthropic_api_key, @@ -343,6 +353,10 @@ def load_and_log_configs(): 'chunk_size': chunk_size, 'overlap': overlap }, + 'auto-save': { + 'save_character_chats': save_character_chats, + 'save_rag_chats': save_rag_chats, + }, 'default_api': default_api } diff --git a/Config_Files/Backup_Config.txt b/Config_Files/Backup_Config.txt index 46d24c7cf..3d497b8b8 100644 --- a/Config_Files/Backup_Config.txt +++ b/Config_Files/Backup_Config.txt @@ -17,23 +17,24 @@ mistral_model = mistral-large-latest mistral_api_key = > custom_openai_api_key = custom_openai_api_ip = +default_api = openai [Local-API] kobold_api_IP = http://127.0.0.1:5001/api/v1/generate -kobold_api_key = +kobold_api_key = llama_api_IP = http://127.0.0.1:8080/completion -llama_api_key = -ooba_api_key = +llama_api_key = +ooba_api_key = ooba_api_IP = http://127.0.0.1:5000/v1/chat/completions tabby_api_IP = http://127.0.0.1:5000/v1/chat/completions -tabby_api_key = +tabby_api_key = vllm_api_IP = http://127.0.0.1:8000/v1/chat/completions -vllm_model = -ollama_api_IP = http://127.0.0.1:11434/api/generate -ollama_api_key = -ollama_model = +vllm_model = +ollama_api_IP = http://127.0.0.1:11434/v1/chat/completions +ollama_api_key = +ollama_model = llama3 aphrodite_api_IP = http://127.0.0.1:8080/completion -aphrodite_api_key = +aphrodite_api_key = [Processing] processing_choice = cuda @@ -41,6 +42,12 @@ processing_choice = cuda [Settings] chunk_duration = 30 words_per_second = 3 +save_character_chats = False +save_rag_chats = False + +[Auto-Save] +save_character_chats = False +save_rag_chats = False [Prompts] prompt_sample = "What is the meaning of life?" 
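A note for reviewers on the hardened `get_database_path()` in `Utils.py` above: reducing the name to its final path component before joining is what keeps every resulting path inside `Databases/`. A minimal standalone sketch of that behavior (the fixed `PROJECT_ROOT` here is illustrative; the real helper derives the root from the location of `Utils.py`):

```python
import os

# Illustrative stand-in for get_project_root(); the real helper walks up
# from the directory containing Utils.py.
PROJECT_ROOT = os.path.abspath(".")

def get_database_path(db_name: str) -> str:
    # os.path.basename() drops all directory components, so an input like
    # "../../etc/passwd" is reduced to "passwd" before joining.
    safe_db_name = os.path.basename(db_name)
    return os.path.join(PROJECT_ROOT, "Databases", safe_db_name)

print(get_database_path("media_summary.db"))  # .../Databases/media_summary.db
print(get_database_path("../../etc/passwd"))  # .../Databases/passwd
```

One caveat: `os.path.basename` follows the host OS's separator rules, so on Windows it also strips backslash-separated components, while on POSIX a backslash is just an ordinary filename character (and therefore harmless there anyway).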
@@ -56,10 +63,14 @@ elasticsearch_port = 9200 # Additionally you can use elasticsearch as the database type, just replace `sqlite` with `elasticsearch` for `type` and provide the `elasticsearch_host` and `elasticsearch_port` of your configured ES instance. chroma_db_path = Databases/chroma_db prompts_db_path = Databases/prompts.db +rag_qa_db_path = Databases/RAG_QA_Chat.db +character_db_path = Databases/chatDB.db [Embeddings] embedding_provider = openai embedding_model = text-embedding-3-small +onnx_model_path = ./App_Function_Libraries/models/onnx_models/ +model_dir = ./App_Function_Libraries/models/embedding_models embedding_api_url = http://localhost:8080/v1/embeddings embedding_api_key = your_api_key_here chunk_size = 400 @@ -78,6 +89,14 @@ adaptive = false multi_level = false language = english +[Metrics] +log_file_path = +#os.getenv("tldw_LOG_FILE_PATH", "tldw_app_logs.json") +max_bytes = +#int(os.getenv("tldw_LOG_MAX_BYTES", 10 * 1024 * 1024)) # 10 MB +backup_count = 5 +#int(os.getenv("tldw_LOG_BACKUP_COUNT", 5)) + #[Comments] #OpenAI Models: diff --git a/Config_Files/config.txt b/Config_Files/config.txt index a0b079af5..45099b81a 100644 --- a/Config_Files/config.txt +++ b/Config_Files/config.txt @@ -17,6 +17,7 @@ mistral_model = mistral-large-latest mistral_api_key = > custom_openai_api_key = custom_openai_api_ip = +default_api = openai [Local-API] kobold_api_IP = http://127.0.0.1:5001/api/v1/generate @@ -42,6 +43,11 @@ processing_choice = cuda chunk_duration = 30 words_per_second = 3 +[Auto-Save] +save_character_chats = False +save_rag_chats = False + + [Prompts] prompt_sample = "What is the meaning of life?" video_summarize_prompt = "Above is the transcript of a video. Please read through the transcript carefully. Identify the main topics that are discussed over the course of the transcript. Then, summarize the key points about each main topic in bullet points. The bullet points should cover the key information conveyed about each topic in the video, but should be much shorter than the full transcript. Please output your bullet point summary inside tags. Do not repeat yourself while writing the summary." @@ -56,7 +62,8 @@ elasticsearch_port = 9200 # Additionally you can use elasticsearch as the database type, just replace `sqlite` with `elasticsearch` for `type` and provide the `elasticsearch_host` and `elasticsearch_port` of your configured ES instance. chroma_db_path = Databases/chroma_db prompts_db_path = Databases/prompts.db -rag_qa_db_path = Databases/rag_qa.db +rag_qa_db_path = Databases/RAG_QA_Chat.db +character_db_path = Databases/chatDB.db [Embeddings] embedding_provider = openai diff --git a/Docs/Documentation.md b/Docs/Documentation.md index d6d959abd..c735d9dd4 100644 --- a/Docs/Documentation.md +++ b/Docs/Documentation.md @@ -19,7 +19,11 @@ ------------------------------------------------------------------------------------------------------------------------ ### Introduction -- +- What is this project? +- What does it do? +- Why is it useful? +- How do I get started? +- Where can I get more help, if I need it? 
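An aside on the new `[Metrics]` section in `Backup_Config.txt` above: the values are left blank, with the intended environment-variable fallbacks only hinted at in comments. A minimal sketch of how such values could be resolved, assuming the key and env-var names from those comments (this loader is illustrative, not the project's actual metrics code):

```python
import configparser
import os

config = configparser.ConfigParser()
config.read("Config_Files/config.txt")

# Blank keys read back as empty strings, so `or` falls through to the
# environment variables named in the config comments, then to defaults.
# (Env-var names are taken from the comments above; loader is illustrative.)
log_file_path = (config.get("Metrics", "log_file_path", fallback="")
                 or os.getenv("tldw_LOG_FILE_PATH", "tldw_app_logs.json"))
max_bytes = int(config.get("Metrics", "max_bytes", fallback="")
                or os.getenv("tldw_LOG_MAX_BYTES", str(10 * 1024 * 1024)))  # 10 MB
backup_count = int(config.get("Metrics", "backup_count", fallback="")
                   or os.getenv("tldw_LOG_BACKUP_COUNT", "5"))
```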
------------------------------------------------------------------------------------------------------------------------ diff --git a/Docs/Issues/ISSUES.md b/Docs/Issues/ISSUES.md index 672053989..c137e0390 100644 --- a/Docs/Issues/ISSUES.md +++ b/Docs/Issues/ISSUES.md @@ -28,4 +28,9 @@ Create a blog post tldwproject.com Linux Cuda - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/user/venv/lib/python3.X/site-packages/nvidia/cudnn/lib/ \ No newline at end of file + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/user/venv/lib/python3.X/site-packages/nvidia/cudnn/lib/ + + +Export/Import From google keep + https://takeout.google.com/ + https://github.com/djsudduth/keep-it-markdown \ No newline at end of file diff --git a/Docs/Prompts/Miscellaneous/Anki_Flashcard_Creation.md b/Docs/Prompts/Miscellaneous/Anki_Flashcard_Creation.md new file mode 100644 index 000000000..dea221146 --- /dev/null +++ b/Docs/Prompts/Miscellaneous/Anki_Flashcard_Creation.md @@ -0,0 +1,74 @@ +### TITLE ### +Anki Flashcard Creation + +### AUTHOR ### +RM + +### SYSTEM ### +You are a helpful AI assistant guiding a user through the process of creating effective flashcards using Anki. + +### USER ### +# Anki Flashcard Creation Guide + +Let's create effective flashcards using proven memory principles. I'll guide you through this process step by step. + +## Initial Information +1. What specific topic/subject are you creating flashcards for? +2. What is your current knowledge level in this subject (beginner/intermediate/advanced)? +3. What is your primary goal for learning this material? + +## Card Creation Guidelines + +For each concept you want to learn, follow these principles: + +### 1. Basic Card Structure +- Front: [Question/Prompt] +- Back: [Answer/Response] + +### 2. Apply These Rules: +- Use atomic content (one fact per card) +- Make it personal and meaningful +- Ensure cards are reversible when appropriate +- Include relevant context +- Use images when helpful + +### 3. Question Types to Consider: +- Basic knowledge ("[What is X?]") +- Application ("[How would you use X in situation Y?]") +- Compare/Contrast ("[What's the difference between X and Y?]") +- Problem-solving ("[Given X conditions, solve for Y]") +- Classifications ("[What category does X belong to?]") + +### 4. Format Check +Your card should pass these tests: +- Can it be answered in under 10 seconds? +- Is the answer clear and unambiguous? +- Would you understand this card 6 months from now? +- Does it avoid "orphan cards" (cards that require other knowledge not in your deck)? + +### 5. Example Template + +FRONT: +[Topic]: [Clear, specific question] +[Additional context if needed] +[Image if relevant] + +BACK: +[Concise answer] +[Key details] +[Mnemonic device if helpful] +[Source reference if important] + +## Interactive Process + +For each concept you want to create a card for, I will: +1. Help you formulate the question +2. Review your proposed answer +3. Suggest improvements +4. Help you identify related cards needed +5. Check for common pitfalls + +Ready to begin? Please share your topic, and we'll create your first card together. 
+ +### KEYWORDS ### +anki,flashcards,generate_cards,learning,study,education diff --git a/Docs/Screenshots/blank-front.png b/Docs/Screenshots/blank-front.png index a0b033ac7..2ef59a73b 100644 Binary files a/Docs/Screenshots/blank-front.png and b/Docs/Screenshots/blank-front.png differ diff --git a/Docs/Screenshots/tldw-run-through-blank.webm b/Docs/Screenshots/tldw-run-through-blank.webm index 554bec7f3..ea098e6f5 100644 Binary files a/Docs/Screenshots/tldw-run-through-blank.webm and b/Docs/Screenshots/tldw-run-through-blank.webm differ diff --git a/Helper_Scripts/DB-Related/migrate_db.py b/Helper_Scripts/DB-Related/migrate_db.py new file mode 100644 index 000000000..6a2dd2653 --- /dev/null +++ b/Helper_Scripts/DB-Related/migrate_db.py @@ -0,0 +1,245 @@ +import sqlite3 +import logging +from datetime import datetime +import os +from pathlib import Path +from typing import List, Tuple, Dict +import argparse +import sys +from tqdm import tqdm + + +class DatabaseMigrator: + def __init__(self, source_db_path: str, target_db_path: str, conversations_export_path: str): + self.source_db_path = Path(source_db_path) + self.target_db_path = Path(target_db_path) + self.conversations_export_path = Path(conversations_export_path) + self.source_conn = None + self.target_conn = None + + # Tables to migrate (in order of dependencies) + self.tables_to_migrate = [ + 'Media', + 'Keywords', + 'MediaKeywords', + 'MediaVersion', + 'MediaModifications', + 'Transcripts', + 'MediaChunks', + 'UnvectorizedMediaChunks', + 'DocumentVersions' + ] + + # Tables to explicitly ignore + self.tables_to_ignore = { + 'media_fts', + 'media_fts_data', + 'media_fts_idx', + 'keyword_fts', + 'ChatConversations', + 'ChatMessages' + } + + def setup_logging(self): + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('migration.log'), + logging.StreamHandler() + ] + ) + + def connect_databases(self): + """Establish connections to both databases""" + try: + self.source_conn = sqlite3.connect(self.source_db_path) + self.target_conn = sqlite3.connect(self.target_db_path) + # Enable foreign keys + self.source_conn.execute("PRAGMA foreign_keys = ON") + self.target_conn.execute("PRAGMA foreign_keys = ON") + except Exception as e: + logging.error(f"Failed to connect to databases: {str(e)}") + raise + + def export_conversations(self): + """Export all conversations to markdown files""" + try: + # Create export directory if it doesn't exist + self.conversations_export_path.mkdir(parents=True, exist_ok=True) + + # Get all conversations + conversations_query = """ + SELECT id, media_id, media_name, conversation_name, created_at + FROM ChatConversations + """ + messages_query = """ + SELECT sender, message, timestamp + FROM ChatMessages + WHERE conversation_id = ? 
+ ORDER BY timestamp + """ + + conversations = self.source_conn.execute(conversations_query).fetchall() + + # Add progress bar for conversation export + print("Exporting chats:") + for conv in tqdm(conversations): + conv_id, media_id, media_name, conv_name, created_at = conv + + # Create filename from conversation details + filename = f"{created_at}_{media_name or 'no_media'}_{conv_name or f'conversation_{conv_id}'}.md" + filename = "".join(c if c.isalnum() or c in ".-_" else "_" for c in filename) + + messages = self.source_conn.execute(messages_query, (conv_id,)).fetchall() + + # Write conversation to markdown file + with open(self.conversations_export_path / filename, 'w', encoding='utf-8') as f: + f.write(f"# Conversation: {conv_name or 'Untitled'}\n") + f.write(f"Media: {media_name or 'None'}\n") + f.write(f"Created: {created_at}\n\n") + f.write("---\n\n") + + for sender, message, timestamp in messages: + f.write(f"**{sender}** ({timestamp}):\n") + f.write(f"{message}\n\n") + + except Exception as e: + logging.error(f"Failed to export conversations: {str(e)}") + raise + + def migrate_table(self, table_name: str): + """Migrate a single table's data""" + try: + # Skip if table is in ignore list + if table_name in self.tables_to_ignore: + return + + # Get data + data = self.source_conn.execute(f"SELECT * FROM {table_name}").fetchall() + if not data: + return + + # Get column names + columns = [desc[0] for desc in self.source_conn.execute(f"SELECT * FROM {table_name} LIMIT 0").description] + + # Begin transaction in target database + with self.target_conn: + # Insert data with progress bar + print(f"Migrating {table_name}:") + placeholders = ','.join(['?' for _ in columns]) + insert_sql = f"INSERT INTO {table_name} ({','.join(columns)}) VALUES ({placeholders})" + + for row in tqdm(data): + try: + self.target_conn.execute(insert_sql, row) + except Exception as e: + logging.error(f"Error inserting row in {table_name}: {str(e)}") + raise + + self.target_conn.commit() + + except Exception as e: + logging.error(f"Failed to migrate table {table_name}: {str(e)}") + raise + + def perform_migration(self): + """Execute the complete migration process""" + try: + self.setup_logging() + logging.info("Starting database migration") + + self.connect_databases() + + # Export conversations first + self.export_conversations() + + # Migrate each table in order + for table in self.tables_to_migrate: + self.migrate_table(table) + + logging.info("Migration completed successfully") + + except KeyboardInterrupt: + logging.error("Migration interrupted by user") + raise + except Exception as e: + logging.error(f"Migration failed: {str(e)}") + raise + finally: + if self.source_conn: + self.source_conn.close() + if self.target_conn: + self.target_conn.close() + + +def validate_paths(source_db: str, target_db: str, export_path: str) -> None: + """Validate the provided paths""" + # Check source database exists + if not os.path.isfile(source_db): + raise ValueError(f"Source database does not exist: {source_db}") + + # Check target database path is writable + target_dir = os.path.dirname(target_db) or '.' 
+ if not os.access(target_dir, os.W_OK): + raise ValueError(f"Cannot write to target database location: {target_dir}") + + # Check export path is writable + export_dir = Path(export_path) + try: + export_dir.mkdir(parents=True, exist_ok=True) + except Exception as e: + raise ValueError(f"Cannot create export directory: {export_path}") from e + + +def parse_arguments(): + """Parse and validate command line arguments""" + parser = argparse.ArgumentParser(description='Migrate SQLite database and export conversations to markdown.') + + parser.add_argument( + '--source-db', + required=True, + help='Path to the source database file' + ) + + parser.add_argument( + '--target-db', + required=True, + help='Path where the new database will be created' + ) + + parser.add_argument( + '--export-path', + required=True, + help='Directory where conversations will be exported as markdown' + ) + + args = parser.parse_args() + + try: + validate_paths(args.source_db, args.target_db, args.export_path) + except ValueError as e: + parser.error(str(e)) + + return args + + +def main(): + try: + # Parse command line arguments + args = parse_arguments() + + # Create and run migrator + migrator = DatabaseMigrator(args.source_db, args.target_db, args.export_path) + migrator.perform_migration() + + except KeyboardInterrupt: + print("\nMigration interrupted by user") + sys.exit(1) + except Exception as e: + logging.error(f"Migration failed: {str(e)}") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Helper_Scripts/Installer_Scripts/Windows_Install_Update.bat b/Helper_Scripts/Installer_Scripts/Windows_Install_Update.bat index aa189defd..baf6d6cc9 100644 --- a/Helper_Scripts/Installer_Scripts/Windows_Install_Update.bat +++ b/Helper_Scripts/Installer_Scripts/Windows_Install_Update.bat @@ -156,7 +156,7 @@ move ffmpeg\ffmpeg-master-latest-win64-gpl\bin\ffmpeg.exe . rmdir /s /q ffmpeg del ffmpeg.zip mkdir .\Bin -move ffmpeg .\Bin +move ffmpeg.exe .\Bin goto :eof :cleanup diff --git a/README.md b/README.md index 2723171dc..8a5a79c8d 100644 --- a/README.md +++ b/README.md @@ -14,21 +14,44 @@ ### [Public Demo on HuggingFace Spaces](https://huggingface.co/spaces/oceansweep/Vid-Summarizer/?__theme=dark) - Demo is broken due to a bug in Huggingface spaces/Gradio - **Please Note:** YouTube blocks requests from the demo. You have to provide a logged-in session cookie to bypass it :frowning_face: - Placeholder content is included for the demo. HuggingFace API is also setup in it, so you can select that as your API.) +- **HEADS UP:** If you're updating from a prior version, your Media DB is not compatible with the new version. You'll need to start fresh. + - I've written a script to help you migrate your data from the old DB to the new one. `Helper_Scripts/DB-Related/migrate_db.py`. + - Process to migrate your data: + 1. Install/run the new version of the app. This will create a new `media_summary.db` file in the `Databases` directory. + 2. Run the `migrate_db.py` script with the old and new DB paths as arguments. - `python migrate_db.py --source media_summary_old.db --target media_summary_new.db --export-path .\` + 3. This will migrate all your data from the old DB to the new one, and export the saved conversations to the `export-path` you specify. + 4. 
Re-import any/all saved conversations into the new RAG_QA_Chat.db

 #### ![Video Walkthrough of a Fresh Install](Docs/Screenshots/tldw-run-through-blank.webm)

-#### Screenshot of the Frontpage ![Screenshot](Docs/Screenshots/blank-front.png)`
+#### Screenshot of the Frontpage ![Screenshot](Docs/Screenshots/blank-front.png)

 #### Key Features:
-- Full-text+RAG search across all ingested content (RAG being BM25 + Vector Search/Contextual embeddings + Re-ranking).
-- Local LLM inference for offline usage and chat (via `llamafile`/`HuggingFace Transformers`).
-- Local Embeddings generation for RAG search (via `llamafile`/`llama.cpp`/`HuggingFace Transformers`).
-- Build up a personal knowledge archive, then turn around and use the LLM to help you learn it at a pace your comfortable with.
+- Ingest (transcribe/convert to markdown) content from (multiple) URLs or local files (video, audio, documents, web articles, books, mediawiki dumps) -> Summarize/Analyze -> Chat with/about the content.
+- Build up a personal knowledge archive, then turn around and use the LLM to help you learn it at a pace you're comfortable with.
+- **Full Plaintext & RAG Search Capability** - Search across all ingested content via RAG or 'old-fashioned non-LLM search' (RAG being BM25 + Vector Search/Contextual embeddings + Re-ranking + Contextual Retrieval).
+  - Search by content, title, author, URL, or tags, with support for meta-tags, so that you can have the equivalent of 'folders' for your content (and tags).
+  - If you'd like to see my notes on RAG, see `./Docs/RAG_Notes.md`.
+  - Notes support, a la NotebookLM, so you can keep track of your thoughts and ideas while chatting/learning, with the ability to search across them or use them for RAG.
+- **Local LLM inference for offline usage and chat** - via `llamafile`/`HuggingFace Transformers`.
+- **4 Different Chat UI styles** - Regular chat, Stacked chat, Multi-Response chat (1 prompt, 3 APIs), and 4 separate API chats on one page.
+- **Local Embeddings Generation for RAG Search** - via `llamafile`/`llama.cpp`/`HuggingFace Transformers`.
 - Also writing tools! Grammar/Style checker, Tone Analyzer, Writing editor(feedback), and more.
-- Full Character Chat Support - Create/Edit & Import/Export Character Cards, and chat with them.
+- **Full Character Chat Support** - Create/Edit & Import/Export Character Cards, and chat with them.
+- **Arxiv API Integration** - Search and ingest papers from Arxiv.
+- **Chat Workflows** - A way to string together multiple questions and responses into a single chat.
+  - Use it to create a 'workflow' for a specific task. Configured via a JSON file.
+- **Import Obsidian Notes/Vault** - Import Obsidian Vaults into the DB. (Imported notes are automatically parsed for tags and titles)
+- **Backup Management** - A way to back up the DBs, view backups, and restore from a backup. (4 SQLite DBs: Media, Character Chats, RAG Chats, Embeddings)
+- **Trashcan Support** - A way to 'soft' delete content, and restore it if needed. (Helps with accidental deletions)
+  - The trashcan is only for the MediaDB.
+- **Support for 8 Local LLM APIs:** `Llama.cpp`, `Kobold.cpp`, `Oobabooga`, `TabbyAPI`, `vLLM`, `Ollama`, `Aphrodite`, `Custom OpenAI API`.
+- **Support for 8 Commercial APIs:** `Claude Sonnet 3.5`, `Cohere Command R+`, `DeepSeek`, `Groq`, `HuggingFace`, `Mistral`, `OpenAI`, `OpenRouter`.
+- **Local Audio Recording with Transcription** - Record audio locally and transcribe it.
+- **Structured Prompt Creation and Management** - Create prompts using a structured approach, then edit and use them in your chats, or delete them.
+  - Also have the ability to import prompts individually or in bulk, as well as export them as markdown documents.
+  - See `./Docs/Prompts/` for example prompts, and `./Docs/Prompts/TEMPLATE.md` for the prompt template used in tldw.
+- Features to come: Anki Flashcard Deck Editing (creation is in), Mindmap creation from content (currently under `Utilities`, uses PlantUML), better document handling, migration to a FastAPI backend (Gradio is a placeholder UI), and more.

 #### The original scripts by `the-crypt-keeper` for transcribing and summarizing youtube videos are available here: [scripts here](https://github.com/the-crypt-keeper/tldw/tree/main/tldw-original-scripts)
@@ -321,16 +344,14 @@ You can view the full roadmap on the [GitHub Issues page](https://github.com/rmu
 - These are just the 'standard smaller' models I recommend, there are many more out there, and you can use any of them with this project.
 - One should also be aware that people create 'fine-tunes' and 'merges' of existing models, to create new models that are more suited to their needs.
   - This can result in models that may be better at some tasks but worse at others, so it's important to test and see what works best for you.
-- MS Phi-3.5-mini-128k(32k effective context, censored output): https://huggingface.co/bartowski/Phi-3.5-mini-instruct-GGUF
-  - Fine-tuned to be uncensored somewhat: https://huggingface.co/bartowski/Phi-3.5-mini-instruct_Uncensored-GGUF
-- AWS MegaBeam Mistral (32k effective context): https://huggingface.co/bartowski/MegaBeam-Mistral-7B-512k-GGUF
-- Mistral Nemo Instruct 2407 - https://huggingface.co/QuantFactory/Mistral-Nemo-Instruct-2407-GGUF
 - Llama 3.1 - The native llamas will give you censored output by default, but you can jailbreak them, or use a finetune which has attempted to tune out their refusals.
   - 8B: https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF
+- Mistral Nemo Instruct 2407 - https://huggingface.co/QuantFactory/Mistral-Nemo-Instruct-2407-GGUF
+- AWS MegaBeam Mistral (32k effective context): https://huggingface.co/bartowski/MegaBeam-Mistral-7B-512k-GGUF
 - Mistral Small: https://huggingface.co/bartowski/Mistral-Small-Instruct-2409-GGUF
 - Cohere Command-R - Command-R https://huggingface.co/bartowski/c4ai-command-r-v01-GGUF / Aug2024 version: https://huggingface.co/bartowski/c4ai-command-r-08-2024-GGUF
-- Qwen 2.5 Series(haven't tested these ones yet but they seem promising, almost certainly censored): https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e
+- Qwen 2.5 Series (pretty powerful, less pop-culture knowledge, and somewhat censored): https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e
   - 2.5-3B: https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF
   - 7B: https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF
   - 14B: https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-GGUF
@@ -466,20 +487,23 @@ None of these companies exist to provide AI services in 2024. They're only doi
   1. Lost in the Middle: How Language Models Use Long Contexts(2023)
      - https://arxiv.org/abs/2307.03172
      - `We analyze the performance of language models on two tasks that require identifying relevant information in their input contexts: multi-document question answering and key-value retrieval.
We find that performance can degrade significantly when changing the position of relevant information, indicating that current language models do not robustly make use of information in long input contexts. In particular, we observe that performance is often highest when relevant information occurs at the beginning or end of the input context, and significantly degrades when models must access relevant information in the middle of long contexts, even for explicitly long-context models` - 2. [RULER: What's the Real Context Size of Your Long-Context Language Models?(2024)](https://arxiv.org/abs/2404.06654) - - `The needle-in-a-haystack (NIAH) test, which examines the ability to retrieve a piece of information (the "needle") from long distractor texts (the "haystack"), has been widely adopted to evaluate long-context language models (LMs). However, this simple retrieval-based test is indicative of only a superficial form of long-context understanding. To provide a more comprehensive evaluation of long-context LMs, we create a new synthetic benchmark RULER with flexible configurations for customized sequence length and task complexity. RULER expands upon the vanilla NIAH test to encompass variations with diverse types and quantities of needles. Moreover, RULER introduces new task categories multi-hop tracing and aggregation to test behaviors beyond searching from context. We evaluate ten long-context LMs with 13 representative tasks in RULER. Despite achieving nearly perfect accuracy in the vanilla NIAH test, all models exhibit large performance drops as the context length increases. While these models all claim context sizes of 32K tokens or greater, only four models (GPT-4, Command-R, Yi-34B, and Mixtral) can maintain satisfactory performance at the length of 32K. Our analysis of Yi-34B, which supports context length of 200K, reveals large room for improvement as we increase input length and task complexity.` - 3. [Same Task, More Tokens: the Impact of Input Length on the Reasoning Performance of Large Language Models(2024)](https://arxiv.org/abs/2402.14848) + 2. [Same Task, More Tokens: the Impact of Input Length on the Reasoning Performance of Large Language Models(2024)](https://arxiv.org/abs/2402.14848) - `Our findings show a notable degradation in LLMs' reasoning performance at much shorter input lengths than their technical maximum. We show that the degradation trend appears in every version of our dataset, although at different intensities. Additionally, our study reveals that the traditional metric of next word prediction correlates negatively with performance of LLMs' on our reasoning dataset. We analyse our results and identify failure modes that can serve as useful guides for future research, potentially informing strategies to address the limitations observed in LLMs.` - 4. Abliteration (Uncensoring LLMs) + 3. Why Does the Effective Context Length of LLMs Fall Short?(2024) + - https://arxiv.org/abs/2410.18745 + - ` Advancements in distributed training and efficient attention mechanisms have significantly expanded the context window sizes of large language models (LLMs). However, recent work reveals that the effective context lengths of open-source LLMs often fall short, typically not exceeding half of their training lengths. In this work, we attribute this limitation to the left-skewed frequency distribution of relative positions formed in LLMs pretraining and post-training stages, which impedes their ability to effectively gather distant information. 
To address this challenge, we introduce ShifTed Rotray position embeddING (STRING). STRING shifts well-trained positions to overwrite the original ineffective positions during inference, enhancing performance within their existing training lengths. Experimental results show that without additional training, STRING dramatically improves the performance of the latest large-scale models, such as Llama3.1 70B and Qwen2 72B, by over 10 points on popular long-context benchmarks RULER and InfiniteBench, establishing new state-of-the-art results for open-source LLMs. Compared to commercial models, Llama 3.1 70B with \method even achieves better performance than GPT-4-128K and clearly surpasses Claude 2 and Kimi-chat.` + 4. [RULER: What's the Real Context Size of Your Long-Context Language Models?(2024)](https://arxiv.org/abs/2404.06654) + - `The needle-in-a-haystack (NIAH) test, which examines the ability to retrieve a piece of information (the "needle") from long distractor texts (the "haystack"), has been widely adopted to evaluate long-context language models (LMs). However, this simple retrieval-based test is indicative of only a superficial form of long-context understanding. To provide a more comprehensive evaluation of long-context LMs, we create a new synthetic benchmark RULER with flexible configurations for customized sequence length and task complexity. RULER expands upon the vanilla NIAH test to encompass variations with diverse types and quantities of needles. Moreover, RULER introduces new task categories multi-hop tracing and aggregation to test behaviors beyond searching from context. We evaluate ten long-context LMs with 13 representative tasks in RULER. Despite achieving nearly perfect accuracy in the vanilla NIAH test, all models exhibit large performance drops as the context length increases. While these models all claim context sizes of 32K tokens or greater, only four models (GPT-4, Command-R, Yi-34B, and Mixtral) can maintain satisfactory performance at the length of 32K. Our analysis of Yi-34B, which supports context length of 200K, reveals large room for improvement as we increase input length and task complexity.` + 5. Abliteration (Uncensoring LLMs) - [Uncensor any LLM with abliteration - Maxime Labonne(2024)](https://huggingface.co/blog/mlabonne/abliteration) - 5. Retrieval-Augmented-Generation + 6. Retrieval-Augmented-Generation - [Retrieval-Augmented Generation for Large Language Models: A Survey](https://arxiv.org/abs/2312.10997) - https://arxiv.org/abs/2312.10997 - `Retrieval-Augmented Generation (RAG) has emerged as a promising solution by incorporating knowledge from external databases. This enhances the accuracy and credibility of the generation, particularly for knowledge-intensive tasks, and allows for continuous knowledge updates and integration of domain-specific information. RAG synergistically merges LLMs' intrinsic knowledge with the vast, dynamic repositories of external databases. This comprehensive review paper offers a detailed examination of the progression of RAG paradigms, encompassing the Naive RAG, the Advanced RAG, and the Modular RAG. It meticulously scrutinizes the tripartite foundation of RAG frameworks, which includes the retrieval, the generation and the augmentation techniques. The paper highlights the state-of-the-art technologies embedded in each of these critical components, providing a profound understanding of the advancements in RAG systems. Furthermore, this paper introduces up-to-date evaluation framework and benchmark. 
At the end, this article delineates the challenges currently faced and points out prospective avenues for research and development. ` - 6. Prompt Engineering + 7. Prompt Engineering - Prompt Engineering Guide: https://www.promptingguide.ai/ & https://github.com/dair-ai/Prompt-Engineering-Guide - 'The Prompt Report' - https://arxiv.org/abs/2406.06608 - 7. Bias and Fairness in LLMs + 8. Bias and Fairness in LLMs - [ChatGPT Doesn't Trust Chargers Fans: Guardrail Sensitivity in Context](https://arxiv.org/abs/2407.06866) - `While the biases of language models in production are extensively documented, the biases of their guardrails have been neglected. This paper studies how contextual information about the user influences the likelihood of an LLM to refuse to execute a request. By generating user biographies that offer ideological and demographic information, we find a number of biases in guardrail sensitivity on GPT-3.5. Younger, female, and Asian-American personas are more likely to trigger a refusal guardrail when requesting censored or illegal information. Guardrails are also sycophantic, refusing to comply with requests for a political position the user is likely to disagree with. We find that certain identity groups and seemingly innocuous information, e.g., sports fandom, can elicit changes in guardrail sensitivity similar to direct statements of political ideology. For each demographic category and even for American football team fandom, we find that ChatGPT appears to infer a likely political ideology and modify guardrail behavior accordingly.` - **Tools & Libraries** @@ -567,7 +591,7 @@ In order of attempts: ------------ ### Credits -- [The original version of this project by @the-crypt-keeper](https://github.com/the-crypt-keeper/tldw) +- [The original version of this project by @the-crypt-keeper](https://github.com/the-crypt-keeper/tldw/tree/main/tldw-original-scripts) - [yt-dlp](https://github.com/yt-dlp/yt-dlp) - [ffmpeg](https://github.com/FFmpeg/FFmpeg) - [faster_whisper](https://github.com/SYSTRAN/faster-whisper) diff --git a/Tests/ChromaDB/test_chromadb.py b/Tests/ChromaDB/test_chromadb.py index 822401580..f3526f6b4 100644 --- a/Tests/ChromaDB/test_chromadb.py +++ b/Tests/ChromaDB/test_chromadb.py @@ -20,7 +20,7 @@ # Local Imports from App_Function_Libraries.RAG.ChromaDB_Library import ( - preprocess_all_content, process_and_store_content, check_embedding_status, + process_and_store_content, check_embedding_status, reset_chroma_collection, vector_search, store_in_chroma, batched, situate_context, schedule_embedding, embedding_api_url ) @@ -86,23 +86,6 @@ def mock_mark_media_processed(mocker): """Fixture to mock mark_media_as_processed.""" return mocker.patch("App_Function_Libraries.RAG.ChromaDB_Library.mark_media_as_processed") -def test_preprocess_all_content(mock_unprocessed_media, mock_process_and_store, mock_mark_media_processed, mock_database, mocker): - # Mock get_unprocessed_media to return unprocessed media - mocker.patch('App_Function_Libraries.RAG.ChromaDB_Library.get_unprocessed_media', return_value=mock_unprocessed_media) - - preprocess_all_content(database=mock_database, create_contextualized=False) - - mock_process_and_store.assert_called_once_with( - database=mock_database, - content="Test Content", - collection_name="video_1", - media_id=1, - file_name="test_file.mp4", - create_embeddings=True, - create_contextualized=False, - api_name="gpt-3.5-turbo" - ) - mock_mark_media_processed.assert_called_once_with(mock_database, 1) ############################## # Test: 
process_and_store_content diff --git a/Tests/Embeddings/__init__.py b/Tests/Embeddings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Tests/RAG/test_RAG_Library_2.py b/Tests/RAG/test_RAG_Library_2.py index d0dc7d057..b31353cd1 100644 --- a/Tests/RAG/test_RAG_Library_2.py +++ b/Tests/RAG/test_RAG_Library_2.py @@ -1,9 +1,9 @@ # Tests/RAG/test_rag_functions.py - +import configparser import os import sys -import unittest -from unittest.mock import patch, MagicMock +import pytest +from unittest.mock import MagicMock from typing import List, Dict, Any # Adjust the path to the parent directory of App_Function_Libraries @@ -15,455 +15,221 @@ from App_Function_Libraries.RAG.RAG_Library_2 import ( fetch_relevant_media_ids, perform_vector_search, - perform_full_text_search + perform_full_text_search, + enhanced_rag_pipeline, + enhanced_rag_pipeline_chat, + generate_answer, + fetch_relevant_chat_ids, + fetch_all_chat_ids, + filter_results_by_keywords, + extract_media_id_from_result ) -class TestRAGFunctions(unittest.TestCase): - """ - Unit tests for RAG-related functions. - """ - - @patch('App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media') - def test_fetch_relevant_media_ids_success(self, mock_fetch_keywords_for_media): - """ - Test fetch_relevant_media_ids with valid keywords. - """ - # Setup mock return values - mock_fetch_keywords_for_media.side_effect = lambda keyword: { +def test_fetch_relevant_media_ids_success(mocker): + """Test fetch_relevant_media_ids with valid keywords.""" + mock_fetch_keywords_for_media = mocker.patch( + 'App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media', + side_effect=lambda keyword: { 'geography': [1, 2], 'cities': [2, 3, 4] }.get(keyword, []) - - # Input keywords - keywords = ['geography', 'cities'] - - # Call the function - result = fetch_relevant_media_ids(keywords) - - # Expected result is the union of media_ids: [1,2,3,4] - self.assertEqual(sorted(result), [1, 2, 3, 4]) - - # Assert fetch_keywords_for_media was called correctly - mock_fetch_keywords_for_media.assert_any_call('geography') - mock_fetch_keywords_for_media.assert_any_call('cities') - self.assertEqual(mock_fetch_keywords_for_media.call_count, 2) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media') - def test_fetch_relevant_media_ids_empty_keywords(self, mock_fetch_keywords_for_media): - """ - Test fetch_relevant_media_ids with an empty keywords list. - """ - keywords = [] - result = fetch_relevant_media_ids(keywords) - self.assertEqual(result, []) - mock_fetch_keywords_for_media.assert_not_called() - - @patch('App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media') - @patch('App_Function_Libraries.RAG.RAG_Library_2.logging') - def test_fetch_relevant_media_ids_exception(self, mock_logging, mock_fetch_keywords_for_media): - """ - Test fetch_relevant_media_ids when fetch_keywords_for_media raises an exception. 
- """ - # Configure the mock to raise an exception - mock_fetch_keywords_for_media.side_effect = Exception("Database error") - - keywords = ['geography', 'cities'] - result = fetch_relevant_media_ids(keywords) - - # The function should return an empty list upon exception - self.assertEqual(result, []) - - # Assert that errors were logged for both keywords - mock_logging.error.assert_any_call("Error fetching relevant media IDs for keyword 'geography': Database error") - mock_logging.error.assert_any_call("Error fetching relevant media IDs for keyword 'cities': Database error") - self.assertEqual(mock_logging.error.call_count, 2) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.vector_search') - @patch('App_Function_Libraries.RAG.RAG_Library_2.chroma_client') - def test_perform_vector_search_with_relevant_media_ids(self, mock_chroma_client, mock_vector_search): - """ - Test perform_vector_search with relevant_media_ids provided. - """ - # Setup mock chroma_client to return a list of collections - mock_collection = MagicMock() - mock_collection.name = 'collection1' - mock_chroma_client.list_collections.return_value = [mock_collection] - - # Setup mock vector_search to return search results - mock_vector_search.return_value = [ - {'content': 'Document 1', 'metadata': {'media_id': 1}}, - {'content': 'Document 2', 'metadata': {'media_id': 2}}, - {'content': 'Document 3', 'metadata': {'media_id': 3}}, - ] - - # Input parameters - query = 'sample query' - relevant_media_ids = [1, 3] - - # Call the function - result = perform_vector_search(query, relevant_media_ids) - - # Expected to filter out media_id 2 - expected = [ - {'content': 'Document 1', 'metadata': {'media_id': 1}}, - {'content': 'Document 3', 'metadata': {'media_id': 3}}, - ] - self.assertEqual(result, expected) - - # Assert chroma_client.list_collections was called once - mock_chroma_client.list_collections.assert_called_once() - - # Assert vector_search was called with correct arguments - mock_vector_search.assert_called_once_with('collection1', query, k=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.vector_search') - @patch('App_Function_Libraries.RAG.RAG_Library_2.chroma_client') - def test_perform_vector_search_without_relevant_media_ids(self, mock_chroma_client, mock_vector_search): - """ - Test perform_vector_search without relevant_media_ids (None). 
- """ - # Setup mock chroma_client to return a list of collections - mock_collection = MagicMock() - mock_collection.name = 'collection1' - mock_chroma_client.list_collections.return_value = [mock_collection] - - # Setup mock vector_search to return search results - mock_vector_search.return_value = [ - {'content': 'Document 1', 'metadata': {'media_id': 1}}, - {'content': 'Document 2', 'metadata': {'media_id': 2}}, - ] - - # Input parameters - query = 'sample query' - relevant_media_ids = None - - # Call the function - result = perform_vector_search(query, relevant_media_ids) - - # Expected to return all results - expected = [ - {'content': 'Document 1', 'metadata': {'media_id': 1}}, - {'content': 'Document 2', 'metadata': {'media_id': 2}}, - ] - self.assertEqual(result, expected) - - # Assert chroma_client.list_collections was called once - mock_chroma_client.list_collections.assert_called_once() - - # Assert vector_search was called with correct arguments - mock_vector_search.assert_called_once_with('collection1', query, k=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') - def test_perform_full_text_search_with_relevant_media_ids(self, mock_search_db): - """ - Test perform_full_text_search with relevant_media_ids provided. - """ - # Setup mock search_db to return search results - mock_search_db.return_value = [ - {'content': 'Full text document 1', 'id': 1}, - {'content': 'Full text document 2', 'id': 2}, - {'content': 'Full text document 3', 'id': 3}, - ] - - # Input parameters - query = 'full text query' - relevant_media_ids = [1, 3] - - # Call the function - result = perform_full_text_search(query, relevant_media_ids, fts_top_k=10) - - # Expected to filter out id 2 - expected = [ - {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, - {'content': 'Full text document 3', 'metadata': {'media_id': 3}}, - ] - self.assertEqual(result, expected) - - # Assert search_db was called with correct arguments - mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') - def test_perform_full_text_search_without_relevant_media_ids(self, mock_search_db): - """ - Test perform_full_text_search without relevant_media_ids (None). - """ - # Setup mock search_db to return search results - mock_search_db.return_value = [ - {'content': 'Full text document 1', 'id': 1}, - {'content': 'Full text document 2', 'id': 2}, - ] - - # Input parameters - query = 'full text query' - relevant_media_ids = None - - # Call the function - result = perform_full_text_search(query, relevant_media_ids) - - # Expected to return all results - expected = [ - {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, - {'content': 'Full text document 2', 'metadata': {'media_id': 2}}, - ] - self.assertEqual(result, expected) - - # Assert search_db was called with correct arguments - mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') - def test_perform_full_text_search_empty_results(self, mock_search_db): - """ - Test perform_full_text_search when search_db returns no results. 
- """ - # Setup mock search_db to return empty list - mock_search_db.return_value = [] - - # Input parameters - query = 'full text query' - relevant_media_ids = [1, 2] - - # Call the function - result = perform_full_text_search(query, relevant_media_ids) - - # Expected to return an empty list - expected = [] - self.assertEqual(result, expected) - - # Assert search_db was called with correct arguments - mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media') - @patch('App_Function_Libraries.RAG.RAG_Library_2.logging') - def test_fetch_relevant_media_ids_partial_failure(self, mock_logging, mock_fetch_keywords_for_media): - """ - Test fetch_relevant_media_ids when fetch_keywords_for_media partially fails. - """ - - # Configure the mock to raise an exception for one keyword - def side_effect(keyword): - if keyword == 'geography': - return [1, 2] - elif keyword == 'cities': - raise Exception("Database error") - return [] - - mock_fetch_keywords_for_media.side_effect = side_effect - - keywords = ['geography', 'cities'] - result = fetch_relevant_media_ids(keywords) - - # The function should still return media_ids for 'geography' and skip 'cities' - self.assertEqual(sorted(result), [1, 2]) - - # Assert that an error was logged for 'cities' - mock_logging.error.assert_called_once_with( - "Error fetching relevant media IDs for keyword 'cities': Database error") - - @patch('App_Function_Libraries.RAG.RAG_Library_2.chroma_client') - @patch('App_Function_Libraries.RAG.RAG_Library_2.vector_search') - def test_perform_vector_search_no_collections(self, mock_vector_search, mock_chroma_client): - """ - Test perform_vector_search when there are no collections. - """ - # Setup mock chroma_client to return an empty list of collections - mock_chroma_client.list_collections.return_value = [] - - # Input parameters - query = 'sample query' - relevant_media_ids = [1, 2] - - # Call the function - result = perform_vector_search(query, relevant_media_ids) - - # Expected to return an empty list since there are no collections - expected = [] - self.assertEqual(result, expected) - - # Assert chroma_client.list_collections was called once - mock_chroma_client.list_collections.assert_called_once() - - # Assert vector_search was not called since there are no collections - mock_vector_search.assert_not_called() - - @patch('App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media') - def test_fetch_relevant_media_ids_duplicate_media_ids(self, mock_fetch_keywords_for_media): - """ - Test fetch_relevant_media_ids with duplicate media_ids across keywords. 
- """ - # Setup mock return values with overlapping media_ids - mock_fetch_keywords_for_media.side_effect = lambda keyword: { - 'science': [1, 2, 3], - 'technology': [3, 4, 5], - 'engineering': [5, 6], - }.get(keyword, []) - - # Input keywords - keywords = ['science', 'technology', 'engineering'] - - # Call the function - result = fetch_relevant_media_ids(keywords) - - # Expected result is the unique union of media_ids: [1,2,3,4,5,6] - self.assertEqual(sorted(result), [1, 2, 3, 4, 5, 6]) - - # Assert fetch_keywords_for_media was called correctly - mock_fetch_keywords_for_media.assert_any_call('science') - mock_fetch_keywords_for_media.assert_any_call('technology') - mock_fetch_keywords_for_media.assert_any_call('engineering') - self.assertEqual(mock_fetch_keywords_for_media.call_count, 3) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') - def test_perform_full_text_search_case_insensitive_filtering(self, mock_search_db): - """ - Test perform_full_text_search with case-insensitive filtering of media_ids. - """ - # Setup mock search_db to return mixed-case media_ids - mock_search_db.return_value = [ - {'content': 'Full text document 1', 'id': '1'}, - {'content': 'Full text document 2', 'id': '2'}, - {'content': 'Full text document 3', 'id': '3'}, - ] - - # Input parameters with media_ids as strings - query = 'full text query' - relevant_media_ids = ['1', '3'] - - # Call the function - result = perform_full_text_search(query, relevant_media_ids) - - # Expected to filter out id '2' - expected = [ - {'content': 'Full text document 1', 'metadata': {'media_id': '1'}}, - {'content': 'Full text document 3', 'metadata': {'media_id': '3'}}, - ] - self.assertEqual(result, expected) - - # Assert search_db was called with correct arguments - mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') - def test_perform_full_text_search_multiple_pages(self, mock_search_db): - """ - Test perform_full_text_search with multiple pages of results. - Note: The current implementation fetches only the first page. - """ - # Setup mock search_db to return results from the first page - mock_search_db.return_value = [ - {'content': 'Full text document 1', 'id': 1}, - {'content': 'Full text document 2', 'id': 2}, - {'content': 'Full text document 3', 'id': 3}, - {'content': 'Full text document 4', 'id': 4}, - {'content': 'Full text document 5', 'id': 5}, - ] - - # Input parameters - query = 'full text query' - relevant_media_ids = [1, 2, 3, 4, 5] - - # Call the function - result = perform_full_text_search(query, relevant_media_ids) - - # Expected to return all results - expected = [ - {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, - {'content': 'Full text document 2', 'metadata': {'media_id': 2}}, - {'content': 'Full text document 3', 'metadata': {'media_id': 3}}, - {'content': 'Full text document 4', 'metadata': {'media_id': 4}}, - {'content': 'Full text document 5', 'metadata': {'media_id': 5}}, - ] - self.assertEqual(result, expected) - - # Assert search_db was called with correct arguments - mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=10) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.chroma_client') - @patch('App_Function_Libraries.RAG.RAG_Library_2.vector_search') - def test_perform_vector_search_multiple_collections(self, mock_vector_search, mock_chroma_client): - """ - Test perform_vector_search with multiple collections. 
- """ - # Setup mock chroma_client to return multiple collections - mock_collection1 = MagicMock() - mock_collection1.name = 'collection1' - mock_collection2 = MagicMock() - mock_collection2.name = 'collection2' - mock_chroma_client.list_collections.return_value = [mock_collection1, mock_collection2] - - # Setup mock vector_search to return different results for each collection - def vector_search_side_effect(collection_name, query, k): - if collection_name == 'collection1': - return [ - {'content': 'Collection1 Document 1', 'metadata': {'media_id': 1}}, - {'content': 'Collection1 Document 2', 'metadata': {'media_id': 2}}, - ] - elif collection_name == 'collection2': - return [ - {'content': 'Collection2 Document 1', 'metadata': {'media_id': 3}}, - {'content': 'Collection2 Document 2', 'metadata': {'media_id': 4}}, - ] - return [] - - mock_vector_search.side_effect = vector_search_side_effect - - # Input parameters - query = 'sample query' - relevant_media_ids = [2, 3] - - # Call the function - result = perform_vector_search(query, relevant_media_ids) - - # Expected to filter and include media_id 2 and 3 - expected = [ - {'content': 'Collection1 Document 2', 'metadata': {'media_id': 2}}, - {'content': 'Collection2 Document 1', 'metadata': {'media_id': 3}}, - ] - self.assertEqual(result, expected) - - # Assert chroma_client.list_collections was called once - mock_chroma_client.list_collections.assert_called_once() - - # Assert vector_search was called twice with correct arguments - mock_vector_search.assert_any_call('collection1', query, k=10) - mock_vector_search.assert_any_call('collection2', query, k=10) - self.assertEqual(mock_vector_search.call_count, 2) - - @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') - def test_perform_full_text_search_partial_matches(self, mock_search_db): - """ - Test perform_full_text_search where some media_ids do not match the relevant_media_ids. 
- """ - # Setup mock search_db to return search results - mock_search_db.return_value = [ - {'content': 'Full text document 1', 'id': 1}, - {'content': 'Full text document 2', 'id': 2}, - {'content': 'Full text document 3', 'id': 3}, - {'content': 'Full text document 4', 'id': 4}, - ] - - # Input parameters - query = 'full text query' - relevant_media_ids = [2, 4] - - # Call the function - result = perform_full_text_search(query, relevant_media_ids) - - # Expected to include only media_id 2 and 4 - expected = [ - {'content': 'Full text document 2', 'metadata': {'media_id': 2}}, - {'content': 'Full text document 4', 'metadata': {'media_id': 4}}, - ] - self.assertEqual(result, expected) - - # Assert search_db was called with correct arguments - mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=10) + ) + + keywords = ['geography', 'cities'] + result = fetch_relevant_media_ids(keywords) + assert sorted(result) == [1, 2, 3, 4] + + mock_fetch_keywords_for_media.assert_any_call('geography') + mock_fetch_keywords_for_media.assert_any_call('cities') + assert mock_fetch_keywords_for_media.call_count == 2 + + +def test_perform_full_text_search_with_relevant_ids(mocker): + """Test perform_full_text_search with relevant_ids provided.""" + # Create a transformed response matching the expected format + transformed_response = [ + {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, + {'content': 'Full text document 3', 'metadata': {'media_id': 3}}, + ] + + # Mock the search functions mapping + search_function_mock = lambda query, fts_top_k, relevant_ids: transformed_response + search_functions_mock = { + "Media DB": search_function_mock + } + mocker.patch('App_Function_Libraries.RAG.RAG_Library_2.search_functions', search_functions_mock) + + query = 'full text query' + database_type = "Media DB" + relevant_ids = "1,3" + + result = perform_full_text_search(query, database_type, relevant_ids) + + expected = [ + {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, + {'content': 'Full text document 3', 'metadata': {'media_id': 3}}, + ] + assert result == expected + + +def test_perform_full_text_search_without_relevant_ids(mocker): + """Test perform_full_text_search without relevant_ids.""" + # Create a transformed response matching the expected format + transformed_response = [ + {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, + {'content': 'Full text document 2', 'metadata': {'media_id': 2}}, + ] + + # Mock the search functions mapping + search_function_mock = lambda query, fts_top_k, relevant_ids: transformed_response + search_functions_mock = { + "Media DB": search_function_mock + } + mocker.patch('App_Function_Libraries.RAG.RAG_Library_2.search_functions', search_functions_mock) + + query = 'full text query' + database_type = "Media DB" + relevant_ids = "" + + result = perform_full_text_search(query, database_type, relevant_ids) + + expected = [ + {'content': 'Full text document 1', 'metadata': {'media_id': 1}}, + {'content': 'Full text document 2', 'metadata': {'media_id': 2}}, + ] + assert result == expected + + +@pytest.mark.parametrize("database_type,search_module_path,mock_response", [ + ( + "Media DB", + 'App_Function_Libraries.DB.SQLite_DB.search_media_db', + [{'content': 'Media DB document 1', 'metadata': {'media_id': '1'}}] + ), + ( + "RAG Chat", + 'App_Function_Libraries.DB.RAG_QA_Chat_DB.search_rag_chat', + [{'content': 'RAG Chat document 1', 'metadata': {'media_id': '1'}}] + ), + ( + "RAG Notes", + 
'App_Function_Libraries.DB.RAG_QA_Chat_DB.search_rag_notes',
+        [{'content': 'RAG Notes document 1', 'metadata': {'media_id': '1'}}]
+    ),
+    (
+        "Character Chat",
+        'App_Function_Libraries.DB.Character_Chat_DB.search_character_chat',
+        [{'content': 'Character Chat document 1', 'metadata': {'media_id': '1'}}]
+    ),
+    (
+        "Character Cards",
+        'App_Function_Libraries.DB.Character_Chat_DB.search_character_cards',
+        [{'content': 'Character Cards document 1', 'metadata': {'media_id': '1'}}]
+    )
+])
+def test_perform_full_text_search_different_db_types(mocker, database_type, search_module_path, mock_response):
+    """Test perform_full_text_search with different database types."""
+    # Mock the search functions mapping with an already transformed response
+    search_functions_mock = {
+        database_type: lambda query, fts_top_k, relevant_ids: mock_response
+    }
+    mocker.patch('App_Function_Libraries.RAG.RAG_Library_2.search_functions', search_functions_mock)
+
+    query = 'test query'
+    relevant_ids = "1"
+
+    result = perform_full_text_search(query, database_type, relevant_ids)
+    assert result == mock_response
+
+
+def test_enhanced_rag_pipeline_success(mocker):
+    """Test enhanced_rag_pipeline with a successful flow."""
+    # Mock config
+    mock_config = configparser.ConfigParser()
+    mock_config['Embeddings'] = {'provider': 'openai'}
+    mocker.patch('App_Function_Libraries.RAG.RAG_Library_2.config', mock_config)
+
+    # Mock search functions
+    fts_result = [{'content': 'FTS result', 'id': 1}]
+    vector_result = [{'content': 'Vector result'}]
+
+    mock_search = lambda *args, **kwargs: fts_result
+    search_functions_mock = {
+        "Media DB": mock_search
+    }
+    mocker.patch('App_Function_Libraries.RAG.RAG_Library_2.search_functions', search_functions_mock)
+
+    mocker.patch(
+        'App_Function_Libraries.RAG.RAG_Library_2.perform_vector_search',
+        return_value=vector_result
+    )
+
+    mocker.patch(
+        'App_Function_Libraries.RAG.RAG_Library_2.generate_answer',
+        return_value='Generated answer'
+    )
+
+    # Mock relevant media IDs
+    mocker.patch(
+        'App_Function_Libraries.RAG.RAG_Library_2.fetch_relevant_media_ids',
+        return_value=[1, 2, 3]
+    )
+
+    result = enhanced_rag_pipeline(
+        query='test query',
+        api_choice='OpenAI',
+        keywords='keyword1,keyword2',
+        database_types=["Media DB"]
+    )
+
+    # Check that both vector and FTS results made it into the context
+    assert result['answer'] == 'Generated answer'
+    assert 'Vector result' in result['context']
+    assert 'FTS result' in result['context']
+
+
+def test_enhanced_rag_pipeline_error_handling(mocker):
+    """Test enhanced_rag_pipeline error handling."""
+    mock_config = configparser.ConfigParser()
+    mock_config['Embeddings'] = {'provider': 'openai'}
+    mocker.patch('App_Function_Libraries.RAG.RAG_Library_2.config', mock_config)
+
+    mocker.patch(
+        'App_Function_Libraries.RAG.RAG_Library_2.fetch_relevant_media_ids',
+        side_effect=Exception("Fetch error")
+    )
+
+    result = enhanced_rag_pipeline(
+        query='test query',
+        api_choice='OpenAI',
+        keywords='keyword1',
+        database_types=["Media DB"]
+    )
+
+    assert "An error occurred" in result['answer']
+    assert result['context'] == ""
+
+
+def test_generate_answer_success(mocker):
+    """Test generate_answer with a successful API call."""
+    # Mock config
+    mock_config = configparser.ConfigParser()
+    mock_config['API'] = {'openai_api_key': 'test_key'}
+    mocker.patch(
+        'App_Function_Libraries.RAG.RAG_Library_2.load_comprehensive_config',
+        return_value=mock_config
+    )
+
+    # Mock the summarization function
+    mock_summarize = mocker.patch(
+        'App_Function_Libraries.Summarization.Summarization_General_Lib.summarize_with_openai',
+        return_value='API response'
+    )
+
+    result = generate_answer('OpenAI', 'Test context', 'Test query')
+    assert result == 'API response'

 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    pytest.main(['-v'])
\ No newline at end of file
diff --git a/Tests/RAG/test_enhanced_rag_pipeline.py b/Tests/RAG/test_enhanced_rag_pipeline.py
index 6c8ec1bce..90ce16f92 100644
--- a/Tests/RAG/test_enhanced_rag_pipeline.py
+++ b/Tests/RAG/test_enhanced_rag_pipeline.py
@@ -19,11 +19,9 @@ class TestEnhancedRagPipeline(unittest.TestCase):
     @patch('App_Function_Libraries.RAG.RAG_Library_2.perform_vector_search')
     @patch('App_Function_Libraries.RAG.RAG_Library_2.perform_full_text_search')
     @patch('App_Function_Libraries.RAG.RAG_Library_2.generate_answer')
-    def test_enhanced_rag_pipeline(self, mock_generate_answer, mock_fts_search, mock_vector_search, mock_fetch_keywords):
-        """
-        Test the enhanced_rag_pipeline function by mocking the dependent functions such as
-        vector search, full-text search, and external API calls.
-        """
+    def test_enhanced_rag_pipeline(self, mock_generate_answer, mock_fts_search, mock_vector_search,
+                                   mock_fetch_keywords):
+        """Test the enhanced_rag_pipeline function with less strict (substring) context matching."""

         # Setup mock data
         query = "What is the capital of France?"
@@ -53,15 +51,24 @@ def test_enhanced_rag_pipeline(self, mock_generate_answer, mock_fts_search, mock
         mock_vector_search.assert_called_once_with(query, [1, 2, 3])
         mock_fts_search.assert_called_once_with(query, [1, 2, 3])

-        # Check that generate_answer was called with the correct context and query
-        expected_context = "Paris is the capital of France.\nThe capital of France is Paris."
-        mock_generate_answer.assert_called_once_with(api_choice, expected_context, query)
+        # Instead of checking for an exact string match, check that both pieces are in the context
+        call_args = mock_generate_answer.call_args[0]  # Get the args the mock was called with
+        actual_context = call_args[1]  # Get the context string
+
+        # Verify each piece of content is in the context
+        self.assertIn("Paris is the capital of France.", actual_context)
+        self.assertIn("The capital of France is Paris.", actual_context)
+        self.assertEqual(call_args[0], api_choice)  # API choice should match
+        self.assertEqual(call_args[2], query)  # Query should match

         # Validate the result structure
         self.assertIn("answer", result)
         self.assertIn("context", result)
         self.assertEqual(result["answer"], "Paris is the capital of France.")
-        self.assertEqual(result["context"], expected_context)
+
+        # Verify both pieces are in the result context too
+        self.assertIn("Paris is the capital of France.", result["context"])
+        self.assertIn("The capital of France is Paris.", result["context"])

 if __name__ == '__main__':
     unittest.main()
diff --git a/Tests/RAG_QA_Chat/test_notes_search.py b/Tests/RAG_QA_Chat/test_notes_search.py
new file mode 100644
index 000000000..0a3fbaf87
--- /dev/null
+++ b/Tests/RAG_QA_Chat/test_notes_search.py
@@ -0,0 +1,144 @@
+# test_notes_search.py
+# Pytest tests for the search_notes_titles function from the DB_Manager module.
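+#
+# A minimal usage sketch of the API these tests exercise, inferred from the calls and
+# assertions below rather than from any documented interface:
+#
+#     results, total_pages, total_count = search_notes_titles(
+#         "Note", page=1, results_per_page=2, connection=conn)
+#
+# Each row in results is expected to expose the note title at index 1, and the optional
+# connection argument lets the tests inject an in-memory SQLite database.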
+# +# Imports +import os +import sys +import pytest +from unittest.mock import MagicMock +import sqlite3 +from datetime import datetime +# +# Adjust the path to the parent directory of App_Function_Libraries +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.abspath(os.path.join(current_dir, '..', '..')) +sys.path.append(parent_dir) +# +# Local Imports +from App_Function_Libraries.DB.DB_Manager import search_notes_titles +# +#################################################################################################### +# +# Test Functions + +@pytest.fixture +def db_connection(): + """Create an in-memory SQLite database and populate it with test data.""" + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + # Create the main notes table + cursor.execute(""" + CREATE TABLE rag_qa_notes ( + id INTEGER PRIMARY KEY, + title TEXT, + content TEXT, + timestamp TEXT, + conversation_id TEXT + ) + """) + + # Create the FTS (Full-Text Search) virtual table + cursor.execute(""" + CREATE VIRTUAL TABLE rag_qa_notes_fts USING FTS5(title, content) + """) + + # Sample test data + notes = [ + (1, "First Note", "Content of the first note", datetime.now().isoformat(), "conv1"), + (2, "Second Note", "Content of the second note", datetime.now().isoformat(), "conv2"), + (3, "Another Note", "Content of another note", datetime.now().isoformat(), "conv3"), + (4, "Note Four", "Fourth note content", datetime.now().isoformat(), "conv4"), + (5, "Final Note", "This is the final note", datetime.now().isoformat(), "conv5") + ] + + # Insert data into the main table + cursor.executemany(""" + INSERT INTO rag_qa_notes (id, title, content, timestamp, conversation_id) + VALUES (?, ?, ?, ?, ?) + """, notes) + + # Insert data into the FTS table + for note in notes: + cursor.execute(""" + INSERT INTO rag_qa_notes_fts (rowid, title, content) + VALUES (?, ?, ?) 
+ """, (note[0], note[1], note[2])) + + conn.commit() + yield conn + conn.close() + + +def test_search_notes_titles_with_search_term(db_connection): + """Test searching with a non-empty search term.""" + search_term = "Note" + results, total_pages, total_count = search_notes_titles(search_term, connection=db_connection) + + assert total_count == 5 + assert total_pages == 1 + assert len(results) == 5 + for result in results: + assert "Note" in result[1] + + +def test_search_notes_titles_empty_search_term(db_connection): + """Test searching with an empty search term, which should return all notes.""" + search_term = "" + results, total_pages, total_count = search_notes_titles(search_term, connection=db_connection) + + assert total_count == 5 + assert total_pages == 1 + assert len(results) == 5 + + +def test_search_notes_titles_pagination(db_connection): + """Test pagination functionality.""" + search_term = "" + results_per_page = 2 + + # First page + results, total_pages, total_count = search_notes_titles( + search_term, page=1, results_per_page=results_per_page, connection=db_connection) + assert total_count == 5 + assert total_pages == 3 + assert len(results) == 2 + + # Second page + results_page_2, _, _ = search_notes_titles( + search_term, page=2, results_per_page=results_per_page, connection=db_connection) + assert len(results_page_2) == 2 + + # Third page + results_page_3, _, _ = search_notes_titles( + search_term, page=3, results_per_page=results_per_page, connection=db_connection) + assert len(results_page_3) == 1 + + +def test_search_notes_titles_invalid_page(db_connection): + """Test that a ValueError is raised when an invalid page number is provided.""" + with pytest.raises(ValueError, match="Page number must be 1 or greater."): + search_notes_titles("test", page=0, connection=db_connection) + + +def test_search_notes_titles_db_error(): + """Test that a database error is properly raised and handled.""" + # Create a mock connection + mock_conn = MagicMock() + mock_cursor = mock_conn.cursor.return_value + # Set the side effect of the execute method to raise a database error + mock_cursor.execute.side_effect = sqlite3.Error("Test database error") + + # Now call the function with the mock connection + with pytest.raises(sqlite3.Error) as exc_info: + search_notes_titles("test", connection=mock_conn) + assert "Error searching notes: Test database error" in str(exc_info.value) diff --git a/Tests/SQLite_DB/test_chat_functions.py b/Tests/SQLite_DB/test_chat_functions.py deleted file mode 100644 index 8f7239684..000000000 --- a/Tests/SQLite_DB/test_chat_functions.py +++ /dev/null @@ -1,102 +0,0 @@ -import pytest -from App_Function_Libraries.DB.DB_Manager import ( - create_chat_conversation, - add_chat_message, - get_chat_messages, - update_chat_message, - delete_chat_message, - search_chat_conversations, - get_conversation_name -) - - -@pytest.fixture -def sample_conversation(empty_db): - conversation_id = create_chat_conversation(None, "Test Conversation") - return conversation_id - - -def test_create_chat_conversation(empty_db): - conversation_id = create_chat_conversation(None, "Test Conversation") - assert conversation_id is not None - assert isinstance(conversation_id, int) - - # Commenting out this test as get_conversation_name seems to be not implemented or returning None - # name = get_conversation_name(conversation_id) - # assert name == "Test Conversation" - - -def test_add_chat_message(empty_db, sample_conversation): - message_id = add_chat_message(sample_conversation, "user", 
"Hello, world!") - assert message_id is not None - assert isinstance(message_id, int) - - -def test_get_chat_messages(empty_db, sample_conversation): - add_chat_message(sample_conversation, "user", "Hello, world!") - add_chat_message(sample_conversation, "ai", "Hi there!") - messages = get_chat_messages(sample_conversation) - assert len(messages) == 2 - assert messages[0]['message'] == "Hello, world!" - assert messages[1]['message'] == "Hi there!" - - -def test_update_chat_message(empty_db, sample_conversation): - message_id = add_chat_message(sample_conversation, "user", "Hello, world!") - update_chat_message(message_id, "Updated message") - messages = get_chat_messages(sample_conversation) - assert len(messages) == 1 - assert messages[0]['message'] == "Updated message" - - -def test_delete_chat_message(empty_db, sample_conversation): - message_id = add_chat_message(sample_conversation, "user", "Hello, world!") - delete_chat_message(message_id) - messages = get_chat_messages(sample_conversation) - assert len(messages) == 0 - - -def test_search_chat_conversations(empty_db): - # Create conversations with names that will match our search queries - conv1_id = create_chat_conversation(None, "World Conversation") - conv2_id = create_chat_conversation(None, "Python Discussion") - conv3_id = create_chat_conversation(None, "Test Conversation") - - # Add messages (these won't affect the search results based on the current implementation) - add_chat_message(conv1_id, "user", "Hello, world!") - add_chat_message(conv2_id, "user", "Python is great") - add_chat_message(conv3_id, "user", "This is a test message") - - print(f"Created conversations: {conv1_id}, {conv2_id}, {conv3_id}") - - results = search_chat_conversations("World") - print(f"Search results for 'World': {results}") - assert len(results) > 0, "Search should return at least one result for 'World'" - assert any("World" in result['conversation_name'] for result in results) - - results = search_chat_conversations("Python") - print(f"Search results for 'Python': {results}") - assert len(results) > 0, "Search should return at least one result for 'Python'" - assert any("Python" in result['conversation_name'] for result in results) - - results = search_chat_conversations("Test") - print(f"Search results for 'Test': {results}") - assert len(results) > 0, "Search should return at least one result for 'Test'" - assert any("Test" in result['conversation_name'] for result in results) - - # Test partial matching - results = search_chat_conversations("Conver") - print(f"Search results for 'Conver': {results}") - assert len(results) > 0, "Search should return results for partial matches" - - # Add a catch-all search to see if any results are returned - all_results = search_chat_conversations("") - print(f"All search results: {all_results}") - assert len( - all_results) >= 3, "Search should return at least the conversations we just created when given an empty string" - - # Check if our newly created conversations are in the results - new_conversation_ids = {conv1_id, conv2_id, conv3_id} - found_conversations = set(result['id'] for result in all_results) - assert new_conversation_ids.issubset( - found_conversations), "All newly created conversations should be in the search results" \ No newline at end of file diff --git a/Tests/SQLite_DB/test_error_handling.py b/Tests/SQLite_DB/test_error_handling.py index 418cd7dc5..af57a0338 100644 --- a/Tests/SQLite_DB/test_error_handling.py +++ b/Tests/SQLite_DB/test_error_handling.py @@ -1,20 +1,47 @@ # 
tests/test_error_handling.py +from unittest.mock import patch, MagicMock + import pytest import os import sqlite3 -from App_Function_Libraries.DB.SQLite_DB import Database, DatabaseError, InputError, add_keyword, delete_keyword, add_media_with_keywords, sqlite_search_db +from App_Function_Libraries.DB.SQLite_DB import Database, DatabaseError, InputError, add_keyword, delete_keyword, \ + add_media_with_keywords, search_media_db @pytest.fixture def test_db(tmp_path): - db_file = tmp_path / "test.db" - db = Database(str(db_file)) + """Create a test database file""" + db_path = tmp_path / "test.db" + db = Database(str(db_path)) + + # Initialize database with required tables + with patch('App_Function_Libraries.DB.SQLite_DB.Database.table_exists', return_value=False): + with db.get_connection() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS test ( + id INTEGER PRIMARY KEY + ) + """) + yield db - db.close_connection() # Ensure the connection is closed - try: - os.remove(db_file) - except PermissionError: - print(f"Warning: Unable to delete {db_file}. It may still be in use.") + + db.close_connection() + if os.path.exists(db_path): + try: + os.remove(db_path) + except PermissionError: + print(f"Warning: Unable to delete {db_path}. It may still be in use.") + + +@pytest.fixture +def mock_db(): + """Create a mock database for search tests""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + + with patch('App_Function_Libraries.DB.DB_Manager.db.get_connection', return_value=mock_conn): + yield mock_conn, mock_cursor def test_execute_query_with_invalid_sql(test_db): @@ -33,43 +60,61 @@ def test_table_exists_nonexistent(test_db): def test_transaction_rollback(test_db): with test_db.get_connection() as conn: - conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") + conn.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY)") with pytest.raises(sqlite3.OperationalError): with test_db.transaction() as conn: conn.execute("INSERT INTO test (id) VALUES (?)", (1,)) conn.execute("INSERT INTO nonexistent_table (id) VALUES (?)", (1,)) - # Verify the transaction was rolled back result = test_db.execute_query("SELECT COUNT(*) FROM test") assert result[0][0] == 0 +def test_search_media_db_with_invalid_page(mock_db): + """Test search_media_db with invalid page number""" + with pytest.raises(ValueError, match="Page number must be 1 or greater."): + search_media_db("query", ['title'], "", page=0) + + +def test_search_media_db_with_none_values(mock_db): + """Test search_media_db with None values""" + mock_conn, mock_cursor = mock_db + mock_cursor.fetchall.return_value = [] + + result = search_media_db(None, ['title'], "") + assert isinstance(result, list) + + +def test_search_media_db_with_empty_search_fields(mock_db): + """Test search_media_db with empty search fields list""" + mock_conn, mock_cursor = mock_db + mock_cursor.fetchall.return_value = [] + + result = search_media_db("query", [], "") + assert isinstance(result, list) + + +def test_search_media_db_with_invalid_fields(mock_db): + """Test search_media_db with invalid search fields""" + mock_conn, mock_cursor = mock_db + mock_cursor.execute.side_effect = sqlite3.OperationalError + + with pytest.raises(sqlite3.OperationalError): + search_media_db("query", ['nonexistent_field'], "") + + def test_add_keyword_with_invalid_data(): with pytest.raises(AttributeError): add_keyword(None) + def test_delete_nonexistent_keyword(): result = delete_keyword("nonexistent_keyword") assert "not 
found" in result + def test_add_media_with_invalid_data(): with pytest.raises(InputError): add_media_with_keywords(None, None, None, None, None, None, None, None, None, None) -def test_sqlite_search_db_with_invalid_data(): - # Test with invalid page number - with pytest.raises(ValueError): - sqlite_search_db("query", ['title'], "", 0, 10) - - # Test with None values (should not raise an exception) - result = sqlite_search_db(None, ['title'], "", 1, 10) - assert isinstance(result, list) - - # Test with empty search fields - result = sqlite_search_db("query", [], "", 1, 10) - assert isinstance(result, list) - - # Test with invalid search fields - with pytest.raises(sqlite3.OperationalError): - sqlite_search_db("query", ['nonexistent_field'], "", 1, 10) diff --git a/Tests/SQLite_DB/test_search_functions.py b/Tests/SQLite_DB/test_search_functions.py index 41969ba11..1802f7e66 100644 --- a/Tests/SQLite_DB/test_search_functions.py +++ b/Tests/SQLite_DB/test_search_functions.py @@ -5,7 +5,7 @@ from typing import List, Tuple # # Updated import statement -from App_Function_Libraries.DB.DB_Manager import sqlite_search_db, search_media_database, db +from App_Function_Libraries.DB.DB_Manager import search_media_db, search_media_database, db # # #################################################################################################### @@ -17,108 +17,137 @@ import sqlite3 from contextlib import contextmanager -from App_Function_Libraries.DB.DB_Manager import sqlite_search_db, search_media_database, Database # Modify the functions to accept a connection parameter for testing -def sqlite_search_db_testable(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10, connection=None): - if connection is None: - with db.get_connection() as conn: - return sqlite_search_db(search_query, search_fields, keywords, page, results_per_page) - else: - # Use the provided connection for testing - return sqlite_search_db(search_query, search_fields, keywords, page, results_per_page, connection=connection) - -def search_media_database_testable(query: str, connection=None): - if connection is None: - with db.get_connection() as conn: - return search_media_database(query) - else: - # Use the provided connection for testing - return search_media_database(query, connection=connection) - @pytest.fixture -def mock_connection(): - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - return mock_conn, mock_cursor - -def test_sqlite_search_db(mock_connection): - mock_conn, mock_cursor = mock_connection - mock_cursor.fetchall.return_value = [ +def mock_db(): + """Create a mock database for search tests""" + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + cursor.connection = conn + return conn, cursor + + +def test_search_media_db(mock_db): + conn, cursor = mock_db + test_results = [ (1, 'http://example.com', 'Test Title', 'video', 'content', 'author', '2023-01-01', 'prompt', 'summary') ] - results = sqlite_search_db_testable('Test', ['title'], '', page=1, results_per_page=10, connection=mock_conn) + # Set up cursor mock + cursor.fetchall.return_value = test_results + + with patch('App_Function_Libraries.DB.DB_Manager.Database.get_connection') as mock_get_conn: + with patch('App_Function_Libraries.DB.DB_Manager.Database.execute_query') as mock_execute: + # Configure mocks + mock_get_conn.return_value.__enter__.return_value = conn + mock_execute.return_value = test_results + + # Execute test + results = 
search_media_db('Test', ['title'], '') + + # Verify results + assert len(results) == 1 + assert results[0][2] == 'Test Title' + + # Verify SQL query + actual_query = cursor.execute.call_args[0][0] + actual_params = cursor.execute.call_args[0][1] + assert 'SELECT DISTINCT Media.id, Media.url, Media.title' in actual_query + assert 'WHERE Media.title LIKE ?' in actual_query + assert '%Test%' in actual_params - assert len(results) == 1 - assert results[0][2] == 'Test Title' - mock_cursor.execute.assert_called() - call_args = mock_cursor.execute.call_args[0] - assert 'SELECT DISTINCT Media.id, Media.url, Media.title' in call_args[0] - assert 'WHERE Media.title LIKE ?' in call_args[0] - assert '%Test%' in call_args[1] -def test_sqlite_search_db_with_keywords(mock_connection): - mock_conn, mock_cursor = mock_connection - mock_cursor.fetchall.return_value = [ +def test_search_media_db_with_keywords(mock_db): + conn, cursor = mock_db + test_results = [ (1, 'http://example.com', 'Test Title', 'video', 'content', 'author', '2023-01-01', 'prompt', 'summary') ] - results = sqlite_search_db_testable('Test', ['title'], 'keyword1,keyword2', page=1, results_per_page=10, connection=mock_conn) + cursor.fetchall.return_value = test_results - assert len(results) == 1 - mock_cursor.execute.assert_called() - call_args = mock_cursor.execute.call_args[0] - assert 'EXISTS (SELECT 1 FROM MediaKeywords mk JOIN Keywords k ON mk.keyword_id = k.id WHERE mk.media_id = Media.id AND k.keyword LIKE ?)' in call_args[0] - assert '%keyword1%' in call_args[1] - assert '%keyword2%' in call_args[1] + with patch('App_Function_Libraries.DB.DB_Manager.Database.get_connection') as mock_get_conn: + with patch('App_Function_Libraries.DB.DB_Manager.Database.execute_query') as mock_execute: + mock_get_conn.return_value.__enter__.return_value = conn + mock_execute.return_value = test_results -def test_sqlite_search_db_pagination(mock_connection): - mock_conn, mock_cursor = mock_connection - mock_cursor.fetchall.return_value = [ - (2, 'http://example2.com', 'Second Title', 'article', 'content2', 'author2', '2023-01-02', 'prompt2', 'summary2') - ] + results = search_media_db('Test', ['title'], 'keyword1,keyword2') - results = sqlite_search_db_testable('', ['title'], '', page=2, results_per_page=1, connection=mock_conn) + assert len(results) == 1 + actual_query = cursor.execute.call_args[0][0] + actual_params = cursor.execute.call_args[0][1] + assert 'EXISTS (SELECT 1 FROM MediaKeywords mk JOIN Keywords k ON mk.keyword_id = k.id' in actual_query + assert '%keyword1%' in actual_params + assert '%keyword2%' in actual_params - assert len(results) == 1 - assert results[0][2] == 'Second Title' - mock_cursor.execute.assert_called() - call_args = mock_cursor.execute.call_args[0] - assert 'LIMIT ? OFFSET ?' 
in call_args[0] - assert call_args[1][-2:] == [1, 1] # LIMIT 1 OFFSET 1 -def test_sqlite_search_db_invalid_page(): - with pytest.raises(ValueError, match="Page number must be 1 or greater."): - sqlite_search_db_testable('Test', ['title'], '', page=0, results_per_page=10) +def test_search_media_db_pagination(mock_db): + conn, cursor = mock_db + + page1_results = [ + (1, 'http://example1.com', 'First Title', 'video', 'content1', 'author1', '2023-01-01', 'prompt1', 'summary1')] + page2_results = [(2, 'http://example2.com', 'Second Title', 'article', 'content2', 'author2', '2023-01-02', + 'prompt2', 'summary2')] + + with patch('App_Function_Libraries.DB.DB_Manager.Database.get_connection') as mock_get_conn: + with patch('App_Function_Libraries.DB.DB_Manager.Database.execute_query') as mock_execute: + mock_get_conn.return_value.__enter__.return_value = conn + mock_execute.side_effect = [page1_results, page2_results] + cursor.fetchall.side_effect = [page1_results, page2_results] + results_page_1 = search_media_db('', ['title'], '', page=1, results_per_page=1) + results_page_2 = search_media_db('', ['title'], '', page=2, results_per_page=1) -def test_search_media_database(mock_connection): - mock_conn, mock_cursor = mock_connection - mock_cursor.fetchall.return_value = [ + assert len(results_page_1) == 1 + assert len(results_page_2) == 1 + assert results_page_1[0][2] == 'First Title' + assert results_page_2[0][2] == 'Second Title' + assert results_page_1 != results_page_2 + + +def test_search_media_database(mock_db): + conn, cursor = mock_db + test_results = [ (1, 'Test Title', 'http://example.com') ] - results = search_media_database('Test', mock_conn) + with patch('App_Function_Libraries.DB.DB_Manager.Database.get_connection') as mock_get_conn: + with patch('App_Function_Libraries.DB.DB_Manager.Database.execute_query') as mock_execute: + mock_get_conn.return_value.__enter__.return_value = conn + mock_execute.return_value = test_results + cursor.fetchall.return_value = test_results + + results = search_media_database('Test') + + assert len(results) == 1 + assert results[0] == (1, 'Test Title', 'http://example.com') + actual_query = cursor.execute.call_args[0][0] + actual_params = cursor.execute.call_args[0][1] + assert 'SELECT id, title, url FROM Media WHERE title LIKE ?' 
in actual_query + assert '%Test%' in actual_params - assert len(results) == 1 - assert results[0] == (1, 'Test Title', 'http://example.com') - mock_cursor.execute.assert_called_with( - "SELECT id, title, url FROM Media WHERE title LIKE ?", - ('%Test%',) - ) +def test_search_media_database_error(mock_db): + conn, cursor = mock_db + test_error = sqlite3.Error("Test database error") -def test_search_media_database_error(mock_connection): - mock_conn, mock_cursor = mock_connection - mock_cursor.execute.side_effect = sqlite3.Error("Test database error") + with patch('App_Function_Libraries.DB.DB_Manager.Database.get_connection') as mock_get_conn: + with patch('App_Function_Libraries.DB.DB_Manager.Database.execute_query') as mock_execute: + mock_get_conn.return_value.__enter__.return_value = conn + mock_execute.side_effect = test_error + cursor.execute.side_effect = test_error - with pytest.raises(Exception) as exc_info: - search_media_database('Test', connection=mock_conn) + with pytest.raises(Exception) as exc_info: + search_media_database('Test') - assert str(exc_info.value) == "Error searching media database: Test database error" + assert str(exc_info.value) == "Error searching media database: Test database error" + + +def test_search_media_db_invalid_page(): + with pytest.raises(ValueError, match="Page number must be 1 or greater."): + search_media_db('Test', ['title'], '', page=0, results_per_page=10) # # End of File -#################################################################################################### \ No newline at end of file +#################################################################################################### diff --git a/Tests/SQLite_DB/test_sqlite_db.py b/Tests/SQLite_DB/test_sqlite_db.py index 992798415..6c05b3165 100644 --- a/Tests/SQLite_DB/test_sqlite_db.py +++ b/Tests/SQLite_DB/test_sqlite_db.py @@ -116,7 +116,7 @@ def test_create_tables(db): assert db.table_exists('Media') assert db.table_exists('Keywords') assert db.table_exists('MediaKeywords') - assert db.table_exists('ChatConversations') + def test_multiple_connections(db): diff --git a/requirements.txt b/requirements.txt index ad95b2e62..d2157ea2b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ faster_whisper fire FlashRank fugashi +genanki # well fuck gradio. again. gradio==4.44.1 html2text