diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index c470f25bb..2ec85b147 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -42,6 +42,12 @@ jobs: cd ./Tests/RAG pytest test_RAG_Library_2.py + - name: Test RAG Notes functions with pytest + run: | + pwd + cd ./Tests/RAG_QA_Chat + pytest test_notes_search.py + - name: Test SQLite lib functions with pytest run: | pwd diff --git a/App_Function_Libraries/Audio/Audio_Files.py b/App_Function_Libraries/Audio/Audio_Files.py index b3eabbde3..1fa0b1afa 100644 --- a/App_Function_Libraries/Audio/Audio_Files.py +++ b/App_Function_Libraries/Audio/Audio_Files.py @@ -117,16 +117,15 @@ def process_audio_files(audio_urls, audio_file, whisper_model, api_name, api_key progress = [] all_transcriptions = [] all_summaries = [] - #v2 + temp_files = [] # Keep track of temporary files + def format_transcription_with_timestamps(segments): if keep_timestamps: formatted_segments = [] for segment in segments: start = segment.get('Time_Start', 0) end = segment.get('Time_End', 0) - text = segment.get('Text', '').strip() # Ensure text is stripped of leading/trailing spaces - - # Add the formatted timestamp and text to the list, followed by a newline + text = segment.get('Text', '').strip() formatted_segments.append(f"[{start:.2f}-{end:.2f}] {text}") # Join the segments with a newline to ensure proper formatting @@ -191,205 +190,64 @@ def convert_mp3_to_wav(mp3_file_path): 'language': chunk_language } - # Process multiple URLs - urls = [url.strip() for url in audio_urls.split('\n') if url.strip()] - - for i, url in enumerate(urls): - update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}") - - # Download and process audio file - audio_file_path = download_audio_file(url, use_cookies, cookies) - if not os.path.exists(audio_file_path): - update_progress(f"Downloaded file not found: {audio_file_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(audio_file_path) - update_progress("Audio file downloaded successfully.") - - # Re-encode MP3 to fix potential issues - reencoded_mp3_path = reencode_mp3(audio_file_path) - if not os.path.exists(reencoded_mp3_path): - update_progress(f"Re-encoded file not found: {reencoded_mp3_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(reencoded_mp3_path) - - # Convert re-encoded MP3 to WAV - wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) - if not os.path.exists(wav_file_path): - update_progress(f"Converted WAV file not found: {wav_file_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(wav_file_path) - - # Initialize transcription - transcription = "" - - # Transcribe audio - if diarize: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True) - else: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model) - - # Handle segments nested under 'segments' key - if isinstance(segments, dict) and 'segments' in segments: - segments = segments['segments'] - - if isinstance(segments, list): - # Log first 5 segments for debugging - logging.debug(f"Segments before formatting: {segments[:5]}") - 
transcription = format_transcription_with_timestamps(segments) - logging.debug(f"Formatted transcription (first 500 chars): {transcription[:500]}") - update_progress("Audio transcribed successfully.") - else: - update_progress("Unexpected segments format received from speech_to_text.") - logging.error(f"Unexpected segments format: {segments}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - if not transcription.strip(): - update_progress("Transcription is empty.") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - else: - # Apply chunking - chunked_text = improved_chunking_process(transcription, chunk_options) - - # Summarize - logging.debug(f"Audio Transcription API Name: {api_name}") - if api_name: - try: - summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) - update_progress("Audio summarized successfully.") - except Exception as e: - logging.error(f"Error during summarization: {str(e)}") - summary = "Summary generation failed" - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - else: - summary = "No summary available (API not provided)" + # Process URLs if provided + if audio_urls: + urls = [url.strip() for url in audio_urls.split('\n') if url.strip()] + for i, url in enumerate(urls): + try: + update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}") - all_transcriptions.append(transcription) - all_summaries.append(summary) + # Download and process audio file + audio_file_path = download_audio_file(url, use_cookies, cookies) + if not audio_file_path: + raise FileNotFoundError(f"Failed to download audio from URL: {url}") - # Use custom_title if provided, otherwise use the original filename - title = custom_title if custom_title else os.path.basename(wav_file_path) - - # Add to database - add_media_with_keywords( - url=url, - title=title, - media_type='audio', - content=transcription, - keywords=custom_keywords, - prompt=custom_prompt_input, - summary=summary, - transcription_model=whisper_model, - author="Unknown", - ingestion_date=datetime.now().strftime('%Y-%m-%d') - ) - update_progress("Audio file processed and added to database.") - processed_count += 1 - log_counter( - metric_name="audio_files_processed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - - # Process uploaded file if provided - if audio_file: - url = generate_unique_id() - if os.path.getsize(audio_file.name) > MAX_FILE_SIZE: - update_progress( - f"Uploaded file size exceeds the maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB. 
Skipping this file.") - else: - try: - # Re-encode MP3 to fix potential issues - reencoded_mp3_path = reencode_mp3(audio_file.name) - if not os.path.exists(reencoded_mp3_path): - update_progress(f"Re-encoded file not found: {reencoded_mp3_path}") - return update_progress("Processing failed: Re-encoded file not found"), "", "" + temp_files.append(audio_file_path) + # Process the audio file + reencoded_mp3_path = reencode_mp3(audio_file_path) temp_files.append(reencoded_mp3_path) - # Convert re-encoded MP3 to WAV wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) - if not os.path.exists(wav_file_path): - update_progress(f"Converted WAV file not found: {wav_file_path}") - return update_progress("Processing failed: Converted WAV file not found"), "", "" - temp_files.append(wav_file_path) - # Initialize transcription - transcription = "" - - if diarize: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True) - else: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model) + # Transcribe audio + segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=diarize) - # Handle segments nested under 'segments' key + # Handle segments format if isinstance(segments, dict) and 'segments' in segments: segments = segments['segments'] - if isinstance(segments, list): - transcription = format_transcription_with_timestamps(segments) - else: - update_progress("Unexpected segments format received from speech_to_text.") - logging.error(f"Unexpected segments format: {segments}") + if not isinstance(segments, list): + raise ValueError("Unexpected segments format received from speech_to_text") - chunked_text = improved_chunking_process(transcription, chunk_options) + transcription = format_transcription_with_timestamps(segments) + if not transcription.strip(): + raise ValueError("Empty transcription generated") - logging.debug(f"Audio Transcription API Name: {api_name}") - if api_name: + # Initialize summary with default value + summary = "No summary available" + + # Attempt summarization if API is provided + if api_name and api_name.lower() != "none": try: - summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + chunked_text = improved_chunking_process(transcription, chunk_options) + summary_result = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + if summary_result: + summary = summary_result update_progress("Audio summarized successfully.") except Exception as e: - logging.error(f"Error during summarization: {str(e)}") + logging.error(f"Summarization failed: {str(e)}") summary = "Summary generation failed" - else: - summary = "No summary available (API not provided)" + # Add to results all_transcriptions.append(transcription) all_summaries.append(summary) - # Use custom_title if provided, otherwise use the original filename + # Add to database title = custom_title if custom_title else os.path.basename(wav_file_path) - add_media_with_keywords( - url="Uploaded File", + url=url, title=title, media_type='audio', content=transcription, @@ -400,65 +258,112 @@ def convert_mp3_to_wav(mp3_file_path): author="Unknown", ingestion_date=datetime.now().strftime('%Y-%m-%d') ) - update_progress("Uploaded file processed and added to database.") + processed_count += 1 - log_counter( - metric_name="audio_files_processed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) + update_progress(f"Successfully processed URL {i + 1}") + 
log_counter("audio_files_processed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + except Exception as e: - update_progress(f"Error processing uploaded file: {str(e)}") - logging.error(f"Error processing uploaded file: {str(e)}") failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - return update_progress("Processing failed: Error processing uploaded file"), "", "" - # Final cleanup - if not keep_original: - cleanup_files() + update_progress(f"Failed to process URL {i + 1}: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + continue - end_time = time.time() - processing_time = end_time - start_time - # Log processing time - log_histogram( - metric_name="audio_processing_time_seconds", - value=processing_time, - labels={"whisper_model": whisper_model, "api_name": api_name} - ) + # Process uploaded file if provided + if audio_file: + try: + update_progress("Processing uploaded file...") + if os.path.getsize(audio_file.name) > MAX_FILE_SIZE: + raise ValueError(f"File size exceeds maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB") - # Optionally, log total counts - log_counter( - metric_name="total_audio_files_processed", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=processed_count - ) + reencoded_mp3_path = reencode_mp3(audio_file.name) + temp_files.append(reencoded_mp3_path) - log_counter( - metric_name="total_audio_files_failed", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=failed_count - ) + wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) + temp_files.append(wav_file_path) + + # Transcribe audio + segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=diarize) + + if isinstance(segments, dict) and 'segments' in segments: + segments = segments['segments'] + + if not isinstance(segments, list): + raise ValueError("Unexpected segments format received from speech_to_text") + transcription = format_transcription_with_timestamps(segments) + if not transcription.strip(): + raise ValueError("Empty transcription generated") + + # Initialize summary with default value + summary = "No summary available" - final_progress = update_progress("All processing complete.") - final_transcriptions = "\n\n".join(all_transcriptions) - final_summaries = "\n\n".join(all_summaries) + # Attempt summarization if API is provided + if api_name and api_name.lower() != "none": + try: + chunked_text = improved_chunking_process(transcription, chunk_options) + summary_result = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + if summary_result: + summary = summary_result + update_progress("Audio summarized successfully.") + except Exception as e: + logging.error(f"Summarization failed: {str(e)}") + summary = "Summary generation failed" + + # Add to results + all_transcriptions.append(transcription) + all_summaries.append(summary) + + # Add to database + title = custom_title if custom_title else os.path.basename(wav_file_path) + add_media_with_keywords( + url="Uploaded File", + title=title, + media_type='audio', + content=transcription, + keywords=custom_keywords, + prompt=custom_prompt_input, + summary=summary, + transcription_model=whisper_model, + author="Unknown", + ingestion_date=datetime.now().strftime('%Y-%m-%d') + ) + + processed_count += 1 + update_progress("Successfully processed uploaded file") + 
log_counter("audio_files_processed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + + except Exception as e: + failed_count += 1 + update_progress(f"Failed to process uploaded file: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + + # Cleanup temporary files + if not keep_original: + cleanup_files() + + # Log processing metrics + processing_time = time.time() - start_time + log_histogram("audio_processing_time_seconds", processing_time, + {"whisper_model": whisper_model, "api_name": api_name}) + log_counter("total_audio_files_processed", processed_count, + {"whisper_model": whisper_model, "api_name": api_name}) + log_counter("total_audio_files_failed", failed_count, + {"whisper_model": whisper_model, "api_name": api_name}) + + # Prepare final output + final_progress = update_progress(f"Processing complete. Processed: {processed_count}, Failed: {failed_count}") + final_transcriptions = "\n\n".join(all_transcriptions) if all_transcriptions else "No transcriptions available" + final_summaries = "\n\n".join(all_summaries) if all_summaries else "No summaries available" return final_progress, final_transcriptions, final_summaries except Exception as e: - logging.error(f"Error processing audio files: {str(e)}") - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - cleanup_files() - return update_progress(f"Processing failed: {str(e)}"), "", "" + logging.error(f"Error in process_audio_files: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + if not keep_original: + cleanup_files() + return update_progress(f"Processing failed: {str(e)}"), "No transcriptions available", "No summaries available" def format_transcription_with_timestamps(segments, keep_timestamps): diff --git a/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py b/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py index a17387980..a6c7651d3 100644 --- a/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py +++ b/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py @@ -24,7 +24,7 @@ wait_random_exponential, ) -from App_Function_Libraries.Chat import chat_api_call +from App_Function_Libraries.Chat.Chat_Functions import chat_api_call # ####################################################################################################################### diff --git a/App_Function_Libraries/Books/Book_Ingestion_Lib.py b/App_Function_Libraries/Books/Book_Ingestion_Lib.py index 488d1ba3e..72a9c2f18 100644 --- a/App_Function_Libraries/Books/Book_Ingestion_Lib.py +++ b/App_Function_Libraries/Books/Book_Ingestion_Lib.py @@ -385,109 +385,103 @@ def process_markdown_content(markdown_content, file_path, title, author, keyword return f"Document '{title}' imported successfully. 
Database result: {result}" -def import_file_handler(file, - title, - author, - keywords, - system_prompt, - custom_prompt, - auto_summarize, - api_name, - api_key, - max_chunk_size, - chunk_overlap, - custom_chapter_pattern - ): +def import_file_handler(files, + author, + keywords, + system_prompt, + custom_prompt, + auto_summarize, + api_name, + api_key, + max_chunk_size, + chunk_overlap, + custom_chapter_pattern): try: - log_counter("file_import_attempt", labels={"file_name": file.name}) - - # Handle max_chunk_size - if isinstance(max_chunk_size, str): - max_chunk_size = int(max_chunk_size) if max_chunk_size.strip() else 4000 - elif not isinstance(max_chunk_size, int): - max_chunk_size = 4000 # Default value if not a string or int - - # Handle chunk_overlap - if isinstance(chunk_overlap, str): - chunk_overlap = int(chunk_overlap) if chunk_overlap.strip() else 0 - elif not isinstance(chunk_overlap, int): - chunk_overlap = 0 # Default value if not a string or int - - chunk_options = { - 'method': 'chapter', - 'max_size': max_chunk_size, - 'overlap': chunk_overlap, - 'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None - } + if not files: + return "No files uploaded." - if file is None: - log_counter("file_import_error", labels={"error": "No file uploaded"}) - return "No file uploaded." + # Convert single file to list for consistent processing + if not isinstance(files, list): + files = [files] - file_path = file.name - if not os.path.exists(file_path): - log_counter("file_import_error", labels={"error": "File not found", "file_name": file.name}) - return "Uploaded file not found." + results = [] + for file in files: + log_counter("file_import_attempt", labels={"file_name": file.name}) - start_time = datetime.now() + # Handle max_chunk_size and chunk_overlap + chunk_size = int(max_chunk_size) if isinstance(max_chunk_size, (str, int)) else 4000 + overlap = int(chunk_overlap) if isinstance(chunk_overlap, (str, int)) else 0 - if file_path.lower().endswith('.epub'): - status = import_epub( - file_path, - title, - author, - keywords, - custom_prompt=custom_prompt, - system_prompt=system_prompt, - summary=None, - auto_summarize=auto_summarize, - api_name=api_name, - api_key=api_key, - chunk_options=chunk_options, - custom_chapter_pattern=custom_chapter_pattern - ) - log_counter("epub_import_success", labels={"file_name": file.name}) - result = f"📚 EPUB Imported Successfully:\n{status}" - elif file.name.lower().endswith('.zip'): - status = process_zip_file( - zip_file=file, - title=title, - author=author, - keywords=keywords, - custom_prompt=custom_prompt, - system_prompt=system_prompt, - summary=None, - auto_summarize=auto_summarize, - api_name=api_name, - api_key=api_key, - chunk_options=chunk_options - ) - log_counter("zip_import_success", labels={"file_name": file.name}) - result = f"📦 ZIP Processed Successfully:\n{status}" - elif file.name.lower().endswith(('.chm', '.html', '.pdf', '.xml', '.opml')): - file_type = file.name.split('.')[-1].upper() - log_counter("unsupported_file_type", labels={"file_type": file_type}) - result = f"{file_type} file import is not yet supported." - else: - log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]}) - result = "❌ Unsupported file type. Please upload an `.epub` file or a `.zip` file containing `.epub` files." 
+ chunk_options = { + 'method': 'chapter', + 'max_size': chunk_size, + 'overlap': overlap, + 'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None + } - end_time = datetime.now() - processing_time = (end_time - start_time).total_seconds() - log_histogram("file_import_duration", processing_time, labels={"file_name": file.name}) + file_path = file.name + if not os.path.exists(file_path): + results.append(f"❌ File not found: {file.name}") + continue - return result + start_time = datetime.now() + + # Extract title from filename + title = os.path.splitext(os.path.basename(file_path))[0] + + if file_path.lower().endswith('.epub'): + status = import_epub( + file_path, + title=title, # Use filename as title + author=author, + keywords=keywords, + custom_prompt=custom_prompt, + system_prompt=system_prompt, + summary=None, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key, + chunk_options=chunk_options, + custom_chapter_pattern=custom_chapter_pattern + ) + log_counter("epub_import_success", labels={"file_name": file.name}) + results.append(f"📚 {file.name}: {status}") + + elif file_path.lower().endswith('.zip'): + status = process_zip_file( + zip_file=file, + title=None, # Let each file use its own name + author=author, + keywords=keywords, + custom_prompt=custom_prompt, + system_prompt=system_prompt, + summary=None, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key, + chunk_options=chunk_options + ) + log_counter("zip_import_success", labels={"file_name": file.name}) + results.append(f"📦 {file.name}: {status}") + else: + results.append(f"❌ Unsupported file type: {file.name}") + continue + + end_time = datetime.now() + processing_time = (end_time - start_time).total_seconds() + log_histogram("file_import_duration", processing_time, labels={"file_name": file.name}) + + return "\n\n".join(results) except ValueError as ve: logging.exception(f"Error parsing input values: {str(ve)}") - log_counter("file_import_error", labels={"error": "Invalid input", "file_name": file.name}) return f"❌ Error: Invalid input for chunk size or overlap. Please enter valid numbers." except Exception as e: logging.exception(f"Error during file import: {str(e)}") - log_counter("file_import_error", labels={"error": str(e), "file_name": file.name}) return f"❌ Error during import: {str(e)}" + def read_epub(file_path): """ Reads and extracts text from an EPUB file. 
@@ -568,9 +562,9 @@ def ingest_text_file(file_path, title=None, author=None, keywords=None): # Add the text file to the database add_media_with_keywords( - url=file_path, + url="its_a_book", title=title, - media_type='document', + media_type='book', content=content, keywords=keywords, prompt='No prompt for text files', diff --git a/App_Function_Libraries/Chat.py b/App_Function_Libraries/Chat/Chat_Functions.py similarity index 86% rename from App_Function_Libraries/Chat.py rename to App_Function_Libraries/Chat/Chat_Functions.py index a86ab66cb..f595d40e7 100644 --- a/App_Function_Libraries/Chat.py +++ b/App_Function_Libraries/Chat/Chat_Functions.py @@ -1,4 +1,4 @@ -# Chat.py +# Chat_Functions.py # Chat functions for interacting with the LLMs as chatbots import base64 # Imports @@ -6,6 +6,7 @@ import logging import os import re +import sqlite3 import tempfile import time from datetime import datetime @@ -14,7 +15,8 @@ # External Imports # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_conversation_name, save_chat_history_to_database +from App_Function_Libraries.DB.DB_Manager import start_new_conversation, delete_messages_in_conversation, save_message +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection, get_conversation_name from App_Function_Libraries.LLM_API_Calls import chat_with_openai, chat_with_anthropic, chat_with_cohere, \ chat_with_groq, chat_with_openrouter, chat_with_deepseek, chat_with_mistral, chat_with_huggingface from App_Function_Libraries.LLM_API_Calls_Local import chat_with_aphrodite, chat_with_local_llm, chat_with_ollama, \ @@ -27,6 +29,16 @@ # # Functions: +def approximate_token_count(history): + total_text = '' + for user_msg, bot_msg in history: + if user_msg: + total_text += user_msg + ' ' + if bot_msg: + total_text += bot_msg + ' ' + total_tokens = len(total_text.split()) + return total_tokens + def chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message=None): log_counter("chat_api_call_attempt", labels={"api_endpoint": api_endpoint}) start_time = time.time() @@ -173,56 +185,58 @@ def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, med log_counter("save_chat_history_to_db_attempt") start_time = time.time() logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}") + try: - # Extract the media_id and media_name from the media_content - media_id = None - if isinstance(media_content, dict): + # First check if we can access the database + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + except sqlite3.DatabaseError as db_error: + logging.error(f"Database is corrupted or inaccessible: {str(db_error)}") + return conversation_id, "Database error: The database file appears to be corrupted. Please contact support." 
+ + # Now attempt the save + if not conversation_id: + # Only for new conversations, not updates media_id = None - logging.debug(f"Media content keys: {media_content.keys()}") - if 'content' in media_content: + if isinstance(media_content, dict) and 'content' in media_content: try: content = media_content['content'] - if isinstance(content, str): - content_json = json.loads(content) - elif isinstance(content, dict): - content_json = content - else: - raise ValueError(f"Unexpected content type: {type(content)}") - - # Use the webpage_url as the media_id + content_json = content if isinstance(content, dict) else json.loads(content) media_id = content_json.get('webpage_url') - # Use the title as the media_name - media_name = content_json.get('title') - - logging.info(f"Extracted media_id: {media_id}, media_name: {media_name}") - except json.JSONDecodeError: - logging.error("Failed to decode JSON from media_content['content']") - except Exception as e: - logging.error(f"Error processing media_content: {str(e)}") + media_name = media_name or content_json.get('title', 'Unnamed Media') + except (json.JSONDecodeError, AttributeError) as e: + logging.error(f"Error processing media content: {str(e)}") + media_id = "unknown_media" + media_name = media_name or "Unnamed Media" else: - logging.warning("'content' key not found in media_content") - else: - logging.warning(f"media_content is not a dictionary. Type: {type(media_content)}") - - if media_id is None: - # If we couldn't find a media_id, we'll use a placeholder - media_id = "unknown_media" - logging.warning(f"Unable to extract media_id from media_content. Using placeholder: {media_id}") - - if media_name is None: - media_name = "Unnamed Media" - logging.warning(f"Unable to extract media_name from media_content. Using placeholder: {media_name}") + media_id = "unknown_media" + media_name = media_name or "Unnamed Media" + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + conversation_title = f"{media_name}_{timestamp}" + conversation_id = start_new_conversation(title=conversation_title, media_id=media_id) + logging.info(f"Created new conversation with ID: {conversation_id}") + + # For both new and existing conversations + try: + delete_messages_in_conversation(conversation_id) + for user_msg, assistant_msg in chatbot: + if user_msg: + save_message(conversation_id, "user", user_msg) + if assistant_msg: + save_message(conversation_id, "assistant", assistant_msg) + except sqlite3.DatabaseError as db_error: + logging.error(f"Database error during message save: {str(db_error)}") + return conversation_id, "Database error: Unable to save messages. Please try again or contact support." - # Generate a unique conversation name using media_id and current timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - conversation_name = f"{media_name}_{timestamp}" - - new_conversation_id = save_chat_history_to_database(chatbot, conversation_id, media_id, media_name, - conversation_name) save_duration = time.time() - start_time log_histogram("save_chat_history_to_db_duration", save_duration) log_counter("save_chat_history_to_db_success") - return new_conversation_id, f"Chat history saved successfully as {conversation_name}!" + + return conversation_id, "Chat history saved successfully!" 
+ except Exception as e: log_counter("save_chat_history_to_db_error", labels={"error": str(e)}) error_message = f"Failed to save chat history: {str(e)}" diff --git a/App_Function_Libraries/Chat/__init__.py b/App_Function_Libraries/Chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/App_Function_Libraries/Chunk_Lib.py b/App_Function_Libraries/Chunk_Lib.py index 2a37d3538..8d9ff2b60 100644 --- a/App_Function_Libraries/Chunk_Lib.py +++ b/App_Function_Libraries/Chunk_Lib.py @@ -106,6 +106,7 @@ def load_document(file_path: str) -> str: def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) -> List[Dict[str, Any]]: logging.debug("Improved chunking process started...") + logging.debug(f"Received chunk_options: {chunk_options}") # Extract JSON metadata if present json_content = {} @@ -125,49 +126,70 @@ def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) - text = text[len(header_text):].strip() logging.debug(f"Extracted header text: {header_text}") - options = chunk_options.copy() if chunk_options else {} + # Make a copy of chunk_options and ensure values are correct types + options = {} if chunk_options: - options.update(chunk_options) - - chunk_method = options.get('method', 'words') - max_size = options.get('max_size', 2000) - overlap = options.get('overlap', 0) - language = options.get('language', None) + try: + options['method'] = str(chunk_options.get('method', 'words')) + options['max_size'] = int(chunk_options.get('max_size', 2000)) + options['overlap'] = int(chunk_options.get('overlap', 0)) + # Handle language specially - it can be None + lang = chunk_options.get('language') + options['language'] = str(lang) if lang is not None else None + logging.debug(f"Processed options: {options}") + except Exception as e: + logging.error(f"Error processing chunk options: {e}") + raise + else: + options = {'method': 'words', 'max_size': 2000, 'overlap': 0, 'language': None} + logging.debug("Using default options") - if language is None: - language = detect_language(text) + if options.get('language') is None: + detected_lang = detect_language(text) + options['language'] = str(detected_lang) + logging.debug(f"Detected language: {options['language']}") - if chunk_method == 'json': - chunks = chunk_text_by_json(text, max_size=max_size, overlap=overlap) - else: - chunks = chunk_text(text, chunk_method, max_size, overlap, language) + try: + if options['method'] == 'json': + chunks = chunk_text_by_json(text, max_size=options['max_size'], overlap=options['overlap']) + else: + chunks = chunk_text(text, options['method'], options['max_size'], options['overlap'], options['language']) + logging.debug(f"Created {len(chunks)} chunks using method {options['method']}") + except Exception as e: + logging.error(f"Error in chunking process: {e}") + raise chunks_with_metadata = [] total_chunks = len(chunks) - for i, chunk in enumerate(chunks): - metadata = { - 'chunk_index': i + 1, - 'total_chunks': total_chunks, - 'chunk_method': chunk_method, - 'max_size': max_size, - 'overlap': overlap, - 'language': language, - 'relative_position': (i + 1) / total_chunks - } - metadata.update(json_content) # Add the extracted JSON content to metadata - metadata['header_text'] = header_text # Add the header text to metadata - - if chunk_method == 'json': - chunk_text_content = json.dumps(chunk['json'], ensure_ascii=False) - else: - chunk_text_content = chunk + try: + for i, chunk in enumerate(chunks): + metadata = { + 'chunk_index': i + 1, + 'total_chunks': 
total_chunks, + 'chunk_method': options['method'], + 'max_size': options['max_size'], + 'overlap': options['overlap'], + 'language': options['language'], + 'relative_position': float((i + 1) / total_chunks) + } + metadata.update(json_content) + metadata['header_text'] = header_text + + if options['method'] == 'json': + chunk_text_content = json.dumps(chunk['json'], ensure_ascii=False) + else: + chunk_text_content = chunk - chunks_with_metadata.append({ - 'text': chunk_text_content, - 'metadata': metadata - }) + chunks_with_metadata.append({ + 'text': chunk_text_content, + 'metadata': metadata + }) - return chunks_with_metadata + logging.debug(f"Successfully created metadata for all chunks") + return chunks_with_metadata + except Exception as e: + logging.error(f"Error creating chunk metadata: {e}") + raise def multi_level_chunking(text: str, method: str, max_size: int, overlap: int, language: str) -> List[str]: @@ -220,24 +242,35 @@ def determine_chunk_position(relative_position: float) -> str: def chunk_text_by_words(text: str, max_words: int = 300, overlap: int = 0, language: str = None) -> List[str]: logging.debug("chunk_text_by_words...") - if language is None: - language = detect_language(text) + logging.debug(f"Parameters: max_words={max_words}, overlap={overlap}, language={language}") - if language.startswith('zh'): # Chinese - import jieba - words = list(jieba.cut(text)) - elif language == 'ja': # Japanese - import fugashi - tagger = fugashi.Tagger() - words = [word.surface for word in tagger(text)] - else: # Default to simple splitting for other languages - words = text.split() + try: + if language is None: + language = detect_language(text) + logging.debug(f"Detected language: {language}") + + if language.startswith('zh'): # Chinese + import jieba + words = list(jieba.cut(text)) + elif language == 'ja': # Japanese + import fugashi + tagger = fugashi.Tagger() + words = [word.surface for word in tagger(text)] + else: # Default to simple splitting for other languages + words = text.split() + + logging.debug(f"Total words: {len(words)}") - chunks = [] - for i in range(0, len(words), max_words - overlap): - chunk = ' '.join(words[i:i + max_words]) - chunks.append(chunk) - return post_process_chunks(chunks) + chunks = [] + for i in range(0, len(words), max_words - overlap): + chunk = ' '.join(words[i:i + max_words]) + chunks.append(chunk) + logging.debug(f"Created chunk {len(chunks)} with {len(chunk.split())} words") + + return post_process_chunks(chunks) + except Exception as e: + logging.error(f"Error in chunk_text_by_words: {e}") + raise def chunk_text_by_sentences(text: str, max_sentences: int = 10, overlap: int = 0, language: str = None) -> List[str]: @@ -338,24 +371,24 @@ def get_chunk_metadata(chunk: str, full_text: str, chunk_type: str = "generic", """ chunk_length = len(chunk) start_index = full_text.find(chunk) - end_index = start_index + chunk_length if start_index != -1 else None + end_index = start_index + chunk_length if start_index != -1 else -1 # Calculate a hash for the chunk chunk_hash = hashlib.md5(chunk.encode()).hexdigest() metadata = { - 'start_index': start_index, - 'end_index': end_index, - 'word_count': len(chunk.split()), - 'char_count': chunk_length, + 'start_index': int(start_index), + 'end_index': int(end_index), + 'word_count': int(len(chunk.split())), + 'char_count': int(chunk_length), 'chunk_type': chunk_type, 'language': language, 'chunk_hash': chunk_hash, - 'relative_position': start_index / len(full_text) if len(full_text) > 0 and start_index != -1 
else 0 + 'relative_position': float(start_index / len(full_text) if len(full_text) > 0 and start_index != -1 else 0) } if chunk_type == "chapter": - metadata['chapter_number'] = chapter_number + metadata['chapter_number'] = int(chapter_number) if chapter_number is not None else None metadata['chapter_pattern'] = chapter_pattern return metadata diff --git a/App_Function_Libraries/DB/Character_Chat_DB.py b/App_Function_Libraries/DB/Character_Chat_DB.py index 4050b3fb8..63bbac637 100644 --- a/App_Function_Libraries/DB/Character_Chat_DB.py +++ b/App_Function_Libraries/DB/Character_Chat_DB.py @@ -72,6 +72,54 @@ def initialize_database(): ); """) + # Create FTS5 virtual table for CharacterCards + cursor.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS CharacterCards_fts USING fts5( + name, + description, + personality, + scenario, + system_prompt, + content='CharacterCards', + content_rowid='id' + ); + """) + + # Create triggers to keep FTS5 table in sync with CharacterCards + cursor.executescript(""" + CREATE TRIGGER IF NOT EXISTS CharacterCards_ai AFTER INSERT ON CharacterCards BEGIN + INSERT INTO CharacterCards_fts( + rowid, + name, + description, + personality, + scenario, + system_prompt + ) VALUES ( + new.id, + new.name, + new.description, + new.personality, + new.scenario, + new.system_prompt + ); + END; + + CREATE TRIGGER IF NOT EXISTS CharacterCards_ad AFTER DELETE ON CharacterCards BEGIN + DELETE FROM CharacterCards_fts WHERE rowid = old.id; + END; + + CREATE TRIGGER IF NOT EXISTS CharacterCards_au AFTER UPDATE ON CharacterCards BEGIN + UPDATE CharacterCards_fts SET + name = new.name, + description = new.description, + personality = new.personality, + scenario = new.scenario, + system_prompt = new.system_prompt + WHERE rowid = new.id; + END; + """) + # Create CharacterChats table cursor.execute(""" CREATE TABLE IF NOT EXISTS CharacterChats ( @@ -155,6 +203,7 @@ def setup_chat_database(): setup_chat_database() + ######################################################################################################## # # Character Card handling @@ -560,6 +609,7 @@ def delete_character_chat(chat_id: int) -> bool: finally: conn.close() + def fetch_keywords_for_chats(keywords: List[str]) -> List[int]: """ Fetch chat IDs associated with any of the specified keywords. @@ -589,6 +639,7 @@ def fetch_keywords_for_chats(keywords: List[str]) -> List[int]: finally: conn.close() + def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]: """Save chat history to the CharacterChats table. @@ -596,9 +647,6 @@ def save_chat_history_to_character_db(character_id: int, conversation_name: str, """ return add_character_chat(character_id, conversation_name, chat_history) -def migrate_chat_to_media_db(): - pass - def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]: """ @@ -696,6 +744,316 @@ def fetch_all_chats() -> List[Dict[str, Any]]: logging.error(f"Error fetching all chats: {str(e)}") return [] + +def search_character_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the Character Chat database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of character IDs to filter results. + + Returns: + List of search results with content and metadata. 
+    """
+    if not query.strip():
+        return []
+
+    try:
+        # Construct a WHERE clause to limit the search to relevant character IDs
+        where_clause = ""
+        if relevant_media_ids:
+            placeholders = ','.join(['?'] * len(relevant_media_ids))
+            where_clause = f"CharacterChats.character_id IN ({placeholders})"
+
+        # Perform full-text search using existing search_db function
+        results = search_db(query, ["conversation_name", "chat_history"], where_clause, results_per_page=fts_top_k)
+
+        # Format results
+        formatted_results = []
+        for r in results:
+            formatted_results.append({
+                "content": r['chat_history'],
+                "metadata": {
+                    "chat_id": r['id'],
+                    "conversation_name": r['conversation_name'],
+                    "character_id": r['character_id']
+                }
+            })
+
+        return formatted_results
+
+    except Exception as e:
+        logging.error(f"Error in search_character_chat: {e}")
+        return []
+
+
+def search_character_cards(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
+    """
+    Perform a full-text search on the Character Cards database.
+
+    Args:
+        query: Search query string.
+        fts_top_k: Maximum number of results to return.
+        relevant_media_ids: Optional list of character IDs to filter results.
+
+    Returns:
+        List of search results with content and metadata.
+    """
+    if not query.strip():
+        return []
+
+    try:
+        conn = sqlite3.connect(chat_DB_PATH)
+        cursor = conn.cursor()
+
+        # Construct the query
+        sql_query = """
+            SELECT CharacterCards.id, CharacterCards.name, CharacterCards.description, CharacterCards.personality, CharacterCards.scenario
+            FROM CharacterCards_fts
+            JOIN CharacterCards ON CharacterCards_fts.rowid = CharacterCards.id
+            WHERE CharacterCards_fts MATCH ?
+        """
+
+        params = [query]
+
+        # Add filtering by character IDs if provided
+        if relevant_media_ids:
+            placeholders = ','.join(['?'] * len(relevant_media_ids))
+            sql_query += f" AND CharacterCards.id IN ({placeholders})"
+            params.extend(relevant_media_ids)
+
+        sql_query += " LIMIT ?"
+        params.append(fts_top_k)
+
+        cursor.execute(sql_query, params)
+        rows = cursor.fetchall()
+        columns = [description[0] for description in cursor.description]
+
+        results = [dict(zip(columns, row)) for row in rows]
+
+        # Format results
+        formatted_results = []
+        for r in results:
+            content = f"Name: {r['name']}\nDescription: {r['description']}\nPersonality: {r['personality']}\nScenario: {r['scenario']}"
+            formatted_results.append({
+                "content": content,
+                "metadata": {
+                    "character_id": r['id'],
+                    "name": r['name']
+                }
+            })
+
+        return formatted_results
+
+    except Exception as e:
+        logging.error(f"Error in search_character_cards: {e}")
+        return []
+    finally:
+        conn.close()
+
+
+def fetch_character_ids_by_keywords(keywords: List[str]) -> List[int]:
+    """
+    Fetch character IDs associated with any of the specified keywords.
+
+    Args:
+        keywords (List[str]): List of keywords to search for.
+
+    Returns:
+        List[int]: List of character IDs associated with the keywords.
+    """
+    if not keywords:
+        return []
+
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        # Assuming 'tags' column in CharacterCards table stores tags as JSON array
+        placeholders = ','.join(['?'] * len(keywords))
+        sql_query = f"""
+            SELECT DISTINCT id FROM CharacterCards
+            WHERE EXISTS (
+                SELECT 1 FROM json_each(tags)
+                WHERE json_each.value IN ({placeholders})
+            )
+        """
+        cursor.execute(sql_query, keywords)
+        rows = cursor.fetchall()
+        character_ids = [row[0] for row in rows]
+        return character_ids
+    except Exception as e:
+        logging.error(f"Error in fetch_character_ids_by_keywords: {e}")
+        return []
+    finally:
+        conn.close()
+
+
+###################################################################
+#
+# Character Keywords
+
+def view_char_keywords():
+    try:
+        with sqlite3.connect(chat_DB_PATH) as conn:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT DISTINCT keyword
+                FROM CharacterCards
+                CROSS JOIN json_each(tags)
+                WHERE json_valid(tags)
+                ORDER BY keyword
+            """)
+            keywords = cursor.fetchall()
+            if keywords:
+                keyword_list = [k[0] for k in keywords]
+                return "### Current Character Keywords:\n" + "\n".join(
+                    [f"- {k}" for k in keyword_list])
+            return "No keywords found."
+    except Exception as e:
+        return f"Error retrieving keywords: {str(e)}"
+
+
+def add_char_keywords(name: str, keywords: str):
+    try:
+        keywords_list = [k.strip() for k in keywords.split(",") if k.strip()]
+        with sqlite3.connect('character_chat.db') as conn:
+            cursor = conn.cursor()
+            cursor.execute(
+                "SELECT tags FROM CharacterCards WHERE name = ?",
+                (name,)
+            )
+            result = cursor.fetchone()
+            if not result:
+                return "Character not found."
+
+            current_tags = result[0] if result[0] else "[]"
+            current_keywords = set(current_tags[1:-1].split(',')) if current_tags != "[]" else set()
+            updated_keywords = current_keywords.union(set(keywords_list))
+
+            cursor.execute(
+                "UPDATE CharacterCards SET tags = ? WHERE name = ?",
+                (str(list(updated_keywords)), name)
+            )
+            conn.commit()
+            return f"Successfully added keywords to character {name}"
+    except Exception as e:
+        return f"Error adding keywords: {str(e)}"
+
+
+def delete_char_keyword(char_name: str, keyword: str) -> str:
+    """
+    Delete a keyword from a character's tags.
+
+    Args:
+        char_name (str): The name of the character
+        keyword (str): The keyword to delete
+
+    Returns:
+        str: Success/failure message
+    """
+    try:
+        with sqlite3.connect(chat_DB_PATH) as conn:
+            cursor = conn.cursor()
+
+            # First, check if the character exists
+            cursor.execute("SELECT tags FROM CharacterCards WHERE name = ?", (char_name,))
+            result = cursor.fetchone()
+
+            if not result:
+                return f"Character '{char_name}' not found."
+
+            # Parse existing tags
+            current_tags = json.loads(result[0]) if result[0] else []
+
+            if keyword not in current_tags:
+                return f"Keyword '{keyword}' not found in character '{char_name}' tags."
+
+            # Remove the keyword
+            updated_tags = [tag for tag in current_tags if tag != keyword]
+
+            # Update the character's tags
+            cursor.execute(
+                "UPDATE CharacterCards SET tags = ? WHERE name = ?",
+                (json.dumps(updated_tags), char_name)
+            )
+            conn.commit()
+
+            logging.info(f"Keyword '{keyword}' deleted from character '{char_name}'")
+            return f"Successfully deleted keyword '{keyword}' from character '{char_name}'."
+ + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_char_keywords_to_csv() -> Tuple[str, str]: + """ + Export all character keywords to a CSV file with associated metadata. + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + from tempfile import NamedTemporaryFile + from datetime import datetime + + try: + # Create a temporary CSV file + temp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', newline='') + + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + + # Get all characters and their tags + cursor.execute(""" + SELECT + name, + tags, + (SELECT COUNT(*) FROM CharacterChats WHERE CharacterChats.character_id = CharacterCards.id) as chat_count + FROM CharacterCards + WHERE json_valid(tags) + ORDER BY name + """) + + results = cursor.fetchall() + + # Process the results to create rows for the CSV + csv_rows = [] + for name, tags_json, chat_count in results: + tags = json.loads(tags_json) if tags_json else [] + for tag in tags: + csv_rows.append([ + tag, # keyword + name, # character name + chat_count # number of chats + ]) + + # Write to CSV + writer = csv.writer(temp_file) + writer.writerow(['Keyword', 'Character Name', 'Number of Chats']) + writer.writerows(csv_rows) + + temp_file.close() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + status_msg = f"Successfully exported {len(csv_rows)} character keyword entries to CSV." + logging.info(status_msg) + + return status_msg, temp_file.name + + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "" + +# +# End of Character chat keyword functions +###################################################### + + # # End of Character_Chat_DB.py ####################################################################################################################### diff --git a/App_Function_Libraries/DB/DB_Backups.py b/App_Function_Libraries/DB/DB_Backups.py new file mode 100644 index 000000000..e3b5b784d --- /dev/null +++ b/App_Function_Libraries/DB/DB_Backups.py @@ -0,0 +1,160 @@ +# Backup_Manager.py +# +# Imports: +import os +import shutil +import sqlite3 +from datetime import datetime +import logging +# +# Local Imports: +from App_Function_Libraries.DB.Character_Chat_DB import chat_DB_PATH +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_rag_qa_db_path +from App_Function_Libraries.Utils.Utils import get_project_relative_path +# +# End of Imports +####################################################################################################################### +# +# Functions: + +def init_backup_directory(backup_base_dir: str, db_name: str) -> str: + """Initialize backup directory for a specific database.""" + backup_dir = os.path.join(backup_base_dir, db_name) + os.makedirs(backup_dir, exist_ok=True) + return backup_dir + + +def create_backup(db_path: str, backup_dir: str, db_name: str) -> str: + """Create a full backup of the database.""" + try: + db_path = os.path.abspath(db_path) + backup_dir = os.path.abspath(backup_dir) + + logging.info(f"Creating backup:") + logging.info(f" DB Path: {db_path}") + logging.info(f" Backup Dir: {backup_dir}") + logging.info(f" DB Name: {db_name}") + + # Create subdirectory based on db_name + specific_backup_dir = os.path.join(backup_dir, db_name) + os.makedirs(specific_backup_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = 
os.path.join(specific_backup_dir, f"{db_name}_backup_{timestamp}.db")
+        logging.info(f" Full backup path: {backup_file}")
+
+        # Create a backup using SQLite's backup API
+        with sqlite3.connect(db_path) as source, \
+             sqlite3.connect(backup_file) as target:
+            source.backup(target)
+
+        logging.info(f"Backup created successfully: {backup_file}")
+        return f"Backup created: {backup_file}"
+    except Exception as e:
+        error_msg = f"Failed to create backup: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+
+
+def create_incremental_backup(db_path: str, backup_dir: str, db_name: str) -> str:
+    """Create an incremental backup using VACUUM INTO."""
+    try:
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        backup_file = os.path.join(backup_dir,
+                                   f"{db_name}_incremental_{timestamp}.sqlib")
+
+        with sqlite3.connect(db_path) as conn:
+            conn.execute(f"VACUUM INTO '{backup_file}'")
+
+        logging.info(f"Incremental backup created: {backup_file}")
+        return f"Incremental backup created: {backup_file}"
+    except Exception as e:
+        error_msg = f"Failed to create incremental backup: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+
+
+def list_backups(backup_dir: str) -> str:
+    """List all available backups."""
+    try:
+        backups = [f for f in os.listdir(backup_dir)
+                   if f.endswith(('.db', '.sqlib'))]
+        backups.sort(reverse=True)  # Most recent first
+        return "\n".join(backups) if backups else "No backups found"
+    except Exception as e:
+        error_msg = f"Failed to list backups: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+
+
+def restore_single_db_backup(db_path: str, backup_dir: str, db_name: str, backup_name: str) -> str:
+    """Restore database from a backup file."""
+    try:
+        logging.info(f"Restoring backup: {backup_name}")
+        backup_path = os.path.join(backup_dir, backup_name)
+        if not os.path.exists(backup_path):
+            logging.error(f"Backup file not found: {backup_name}")
+            return f"Backup file not found: {backup_name}"
+
+        # Create a timestamp for the current db
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        current_backup = os.path.join(backup_dir,
+                                      f"{db_name}_pre_restore_{timestamp}.db")
+
+        # Backup current database before restore
+        logging.info(f"Creating backup of current database: {current_backup}")
+        shutil.copy2(db_path, current_backup)
+
+        # Restore the backup
+        logging.info(f"Restoring database from {backup_name}")
+        shutil.copy2(backup_path, db_path)
+
+        logging.info(f"Database restored from {backup_name}")
+        return f"Database restored from {backup_name}"
+    except Exception as e:
+        error_msg = f"Failed to restore backup: {str(e)}"
+        logging.error(error_msg)
+        return error_msg
+
+
+def setup_backup_config():
+    """Setup configuration for database backups."""
+    backup_base_dir = get_project_relative_path('tldw_DB_Backups')
+    logging.info(f"Base backup directory: {os.path.abspath(backup_base_dir)}")
+
+    # RAG Chat DB configuration
+    rag_db_path = get_rag_qa_db_path()
+    rag_backup_dir = os.path.join(backup_base_dir, 'rag_chat')
+    os.makedirs(rag_backup_dir, exist_ok=True)
+    logging.info(f"RAG backup directory: {os.path.abspath(rag_backup_dir)}")
+
+    rag_db_config = {
+        'db_path': rag_db_path,
+        'backup_dir': rag_backup_dir,  # Make sure we use the full path
+        'db_name': 'rag_qa'
+    }
+
+    # Character Chat DB configuration
+    char_backup_dir = os.path.join(backup_base_dir, 'character_chat')
+    os.makedirs(char_backup_dir, exist_ok=True)
+    logging.info(f"Character backup directory: {os.path.abspath(char_backup_dir)}")
+
+    char_db_config = {
+        'db_path': chat_DB_PATH,
+
'backup_dir': char_backup_dir, # Make sure we use the full path + 'db_name': 'chatDB' + } + + # Media DB configuration (based on your logs) + media_backup_dir = os.path.join(backup_base_dir, 'media') + os.makedirs(media_backup_dir, exist_ok=True) + logging.info(f"Media backup directory: {os.path.abspath(media_backup_dir)}") + + media_db_config = { + 'db_path': os.path.join(os.path.dirname(chat_DB_PATH), 'media_summary.db'), + 'backup_dir': media_backup_dir, + 'db_name': 'media' + } + + return rag_db_config, char_db_config, media_db_config + diff --git a/App_Function_Libraries/DB/DB_Manager.py b/App_Function_Libraries/DB/DB_Manager.py index fc0c7b498..13c3cb15d 100644 --- a/App_Function_Libraries/DB/DB_Manager.py +++ b/App_Function_Libraries/DB/DB_Manager.py @@ -13,11 +13,14 @@ # # Import your existing SQLite functions from App_Function_Libraries.DB.SQLite_DB import DatabaseError +from App_Function_Libraries.DB.Prompts_DB import list_prompts as sqlite_list_prompts, \ + fetch_prompt_details as sqlite_fetch_prompt_details, add_prompt as sqlite_add_prompt, \ + search_prompts as sqlite_search_prompts, add_or_update_prompt as sqlite_add_or_update_prompt, \ + load_prompt_details as sqlite_load_prompt_details, insert_prompt_to_db as sqlite_insert_prompt_to_db, \ + delete_prompt as sqlite_delete_prompt from App_Function_Libraries.DB.SQLite_DB import ( update_media_content as sqlite_update_media_content, - list_prompts as sqlite_list_prompts, search_and_display as sqlite_search_and_display, - fetch_prompt_details as sqlite_fetch_prompt_details, keywords_browser_interface as sqlite_keywords_browser_interface, add_keyword as sqlite_add_keyword, delete_keyword as sqlite_delete_keyword, @@ -25,31 +28,17 @@ ingest_article_to_db as sqlite_ingest_article_to_db, add_media_to_database as sqlite_add_media_to_database, import_obsidian_note_to_db as sqlite_import_obsidian_note_to_db, - add_prompt as sqlite_add_prompt, - delete_chat_message as sqlite_delete_chat_message, - update_chat_message as sqlite_update_chat_message, - add_chat_message as sqlite_add_chat_message, - get_chat_messages as sqlite_get_chat_messages, - search_chat_conversations as sqlite_search_chat_conversations, - create_chat_conversation as sqlite_create_chat_conversation, - save_chat_history_to_database as sqlite_save_chat_history_to_database, view_database as sqlite_view_database, get_transcripts as sqlite_get_transcripts, get_trashed_items as sqlite_get_trashed_items, user_delete_item as sqlite_user_delete_item, empty_trash as sqlite_empty_trash, create_automated_backup as sqlite_create_automated_backup, - add_or_update_prompt as sqlite_add_or_update_prompt, - load_prompt_details as sqlite_load_prompt_details, - load_preset_prompts as sqlite_load_preset_prompts, - insert_prompt_to_db as sqlite_insert_prompt_to_db, - delete_prompt as sqlite_delete_prompt, search_and_display_items as sqlite_search_and_display_items, - get_conversation_name as sqlite_get_conversation_name, add_media_with_keywords as sqlite_add_media_with_keywords, check_media_and_whisper_model as sqlite_check_media_and_whisper_model, \ create_document_version as sqlite_create_document_version, - get_document_version as sqlite_get_document_version, sqlite_search_db, add_media_chunk as sqlite_add_media_chunk, + get_document_version as sqlite_get_document_version, search_media_db as sqlite_search_media_db, add_media_chunk as sqlite_add_media_chunk, sqlite_update_fts_for_media, get_unprocessed_media as sqlite_get_unprocessed_media, fetch_item_details as sqlite_fetch_item_details, 
\ search_media_database as sqlite_search_media_database, mark_as_trash as sqlite_mark_as_trash, \ get_media_transcripts as sqlite_get_media_transcripts, get_specific_transcript as sqlite_get_specific_transcript, \ @@ -60,23 +49,35 @@ delete_specific_prompt as sqlite_delete_specific_prompt, fetch_keywords_for_media as sqlite_fetch_keywords_for_media, \ update_keywords_for_media as sqlite_update_keywords_for_media, check_media_exists as sqlite_check_media_exists, \ - search_prompts as sqlite_search_prompts, get_media_content as sqlite_get_media_content, \ - get_paginated_files as sqlite_get_paginated_files, get_media_title as sqlite_get_media_title, \ - get_all_content_from_database as sqlite_get_all_content_from_database, - get_next_media_id as sqlite_get_next_media_id, \ - batch_insert_chunks as sqlite_batch_insert_chunks, Database, save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, \ - get_workflow_chat as sqlite_get_workflow_chat, update_media_content_with_version as sqlite_update_media_content_with_version, \ + get_media_content as sqlite_get_media_content, get_paginated_files as sqlite_get_paginated_files, \ + get_media_title as sqlite_get_media_title, get_all_content_from_database as sqlite_get_all_content_from_database, \ + get_next_media_id as sqlite_get_next_media_id, batch_insert_chunks as sqlite_batch_insert_chunks, Database, \ + save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, get_workflow_chat as sqlite_get_workflow_chat, \ + update_media_content_with_version as sqlite_update_media_content_with_version, \ check_existing_media as sqlite_check_existing_media, get_all_document_versions as sqlite_get_all_document_versions, \ fetch_paginated_data as sqlite_fetch_paginated_data, get_latest_transcription as sqlite_get_latest_transcription, \ mark_media_as_processed as sqlite_mark_media_as_processed, ) +from App_Function_Libraries.DB.RAG_QA_Chat_DB import start_new_conversation as sqlite_start_new_conversation, \ + save_message as sqlite_save_message, load_chat_history as sqlite_load_chat_history, \ + get_all_conversations as sqlite_get_all_conversations, get_notes_by_keywords as sqlite_get_notes_by_keywords, \ + get_note_by_id as sqlite_get_note_by_id, update_note as sqlite_update_note, save_notes as sqlite_save_notes, \ + clear_keywords_from_note as sqlite_clear_keywords_from_note, add_keywords_to_note as sqlite_add_keywords_to_note, \ + add_keywords_to_conversation as sqlite_add_keywords_to_conversation, \ + get_keywords_for_note as sqlite_get_keywords_for_note, delete_note as sqlite_delete_note, \ + search_conversations_by_keywords as sqlite_search_conversations_by_keywords, \ + delete_conversation as sqlite_delete_conversation, get_conversation_title as sqlite_get_conversation_title, \ + update_conversation_title as sqlite_update_conversation_title, \ + fetch_all_conversations as sqlite_fetch_all_conversations, fetch_all_notes as sqlite_fetch_all_notes, \ + fetch_conversations_by_ids as sqlite_fetch_conversations_by_ids, fetch_notes_by_ids as sqlite_fetch_notes_by_ids, \ + delete_messages_in_conversation as sqlite_delete_messages_in_conversation, \ + get_conversation_text as sqlite_get_conversation_text, search_notes_titles as sqlite_search_notes_titles from App_Function_Libraries.DB.Character_Chat_DB import ( add_character_card as sqlite_add_character_card, get_character_cards as sqlite_get_character_cards, \ get_character_card_by_id as sqlite_get_character_card_by_id, update_character_card as sqlite_update_character_card, \ delete_character_card as 
sqlite_delete_character_card, add_character_chat as sqlite_add_character_chat, \ get_character_chats as sqlite_get_character_chats, get_character_chat_by_id as sqlite_get_character_chat_by_id, \ - update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat, \ - migrate_chat_to_media_db as sqlite_migrate_chat_to_media_db, + update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat ) # # Local Imports @@ -214,9 +215,9 @@ def ensure_directory_exists(file_path): # # DB Search functions -def search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10): +def search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10): if db_type == 'sqlite': - return sqlite_search_db(search_query, search_fields, keywords, page, results_per_page) + return sqlite_search_media_db(search_query, search_fields, keywords, page, results_per_page) elif db_type == 'elasticsearch': # Implement Elasticsearch version when available raise NotImplementedError("Elasticsearch version of search_db not yet implemented") @@ -500,13 +501,6 @@ def load_prompt_details(*args, **kwargs): # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def load_preset_prompts(*args, **kwargs): - if db_type == 'sqlite': - return sqlite_load_preset_prompts() - elif db_type == 'elasticsearch': - # Implement Elasticsearch version - raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") - def insert_prompt_to_db(*args, **kwargs): if db_type == 'sqlite': return sqlite_insert_prompt_to_db(*args, **kwargs) @@ -539,7 +533,6 @@ def mark_as_trash(media_id: int) -> None: else: raise ValueError(f"Unsupported database type: {db_type}") - def get_latest_transcription(*args, **kwargs): if db_type == 'sqlite': return sqlite_get_latest_transcription(*args, **kwargs) @@ -721,62 +714,132 @@ def fetch_keywords_for_media(*args, **kwargs): # # Chat-related Functions -def delete_chat_message(*args, **kwargs): +def search_notes_titles(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_search_notes_titles(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def save_message(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_save_message(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def load_chat_history(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_load_chat_history(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def start_new_conversation(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_start_new_conversation(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_all_conversations(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_all_conversations(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch 
version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_notes_by_keywords(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_notes_by_keywords(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_note_by_id(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_note_by_id(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def add_keywords_to_conversation(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_add_keywords_to_conversation(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_keywords_for_note(*args, **kwargs): if db_type == 'sqlite': - return sqlite_delete_chat_message(*args, **kwargs) + return sqlite_get_keywords_for_note(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def update_chat_message(*args, **kwargs): +def delete_note(*args, **kwargs): if db_type == 'sqlite': - return sqlite_update_chat_message(*args, **kwargs) + return sqlite_delete_note(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def add_chat_message(*args, **kwargs): +def search_conversations_by_keywords(*args, **kwargs): if db_type == 'sqlite': - return sqlite_add_chat_message(*args, **kwargs) + return sqlite_search_conversations_by_keywords(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def get_chat_messages(*args, **kwargs): +def delete_conversation(*args, **kwargs): if db_type == 'sqlite': - return sqlite_get_chat_messages(*args, **kwargs) + return sqlite_delete_conversation(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def search_chat_conversations(*args, **kwargs): +def get_conversation_title(*args, **kwargs): if db_type == 'sqlite': - return sqlite_search_chat_conversations(*args, **kwargs) + return sqlite_get_conversation_title(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def create_chat_conversation(*args, **kwargs): +def update_conversation_title(*args, **kwargs): if db_type == 'sqlite': - return sqlite_create_chat_conversation(*args, **kwargs) + return sqlite_update_conversation_title(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def save_chat_history_to_database(*args, **kwargs): +def fetch_all_conversations(*args, **kwargs): if db_type == 'sqlite': - return sqlite_save_chat_history_to_database(*args, **kwargs) + return sqlite_fetch_all_conversations() elif db_type == 
'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def get_conversation_name(*args, **kwargs): +def fetch_all_notes(*args, **kwargs): if db_type == 'sqlite': - return sqlite_get_conversation_name(*args, **kwargs) + return sqlite_fetch_all_notes() elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") +def delete_messages_in_conversation(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_delete_messages_in_conversation(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of delete_messages_in_conversation not yet implemented") + +def get_conversation_text(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_conversation_text(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of get_conversation_text not yet implemented") + # # End of Chat-related Functions ############################################################################################################ @@ -856,12 +919,54 @@ def delete_character_chat(*args, **kwargs): # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of delete_character_chat not yet implemented") -def migrate_chat_to_media_db(*args, **kwargs): +def update_note(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_update_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of update_note not yet implemented") + +def save_notes(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_save_notes(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of save_notes not yet implemented") + +def clear_keywords(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_clear_keywords_from_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of clear_keywords not yet implemented") + +def clear_keywords_from_note(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_clear_keywords_from_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of clear_keywords_from_note not yet implemented") + +def add_keywords_to_note(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_add_keywords_to_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_keywords_to_note not yet implemented") + +def fetch_conversations_by_ids(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_fetch_conversations_by_ids(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of fetch_conversations_by_ids not yet implemented") + +def fetch_notes_by_ids(*args, **kwargs): if db_type == 'sqlite': - return sqlite_migrate_chat_to_media_db(*args, **kwargs) + return sqlite_fetch_notes_by_ids(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version - raise NotImplementedError("Elasticsearch version of 
migrate_chat_to_media_db not yet implemented") + raise NotImplementedError("Elasticsearch version of fetch_notes_by_ids not yet implemented") # # End of Character Chat-related Functions diff --git a/App_Function_Libraries/DB/Prompts_DB.py b/App_Function_Libraries/DB/Prompts_DB.py new file mode 100644 index 000000000..68d8eae55 --- /dev/null +++ b/App_Function_Libraries/DB/Prompts_DB.py @@ -0,0 +1,626 @@ +# Prompts_DB.py +# Description: Functions to manage the prompts database. +# +# Imports +import sqlite3 +import logging +# +# External Imports +import re +from typing import Tuple +# +# Local Imports +from App_Function_Libraries.Utils.Utils import get_database_path +# +####################################################################################################################### +# +# Functions to manage prompts DB + +def create_prompts_db(): + logging.debug("create_prompts_db: Creating prompts database.") + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.executescript(''' + CREATE TABLE IF NOT EXISTS Prompts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + author TEXT, + details TEXT, + system TEXT, + user TEXT + ); + CREATE TABLE IF NOT EXISTS Keywords ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyword TEXT NOT NULL UNIQUE COLLATE NOCASE + ); + CREATE TABLE IF NOT EXISTS PromptKeywords ( + prompt_id INTEGER, + keyword_id INTEGER, + FOREIGN KEY (prompt_id) REFERENCES Prompts (id), + FOREIGN KEY (keyword_id) REFERENCES Keywords (id), + PRIMARY KEY (prompt_id, keyword_id) + ); + CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON Keywords(keyword); + CREATE INDEX IF NOT EXISTS idx_promptkeywords_prompt_id ON PromptKeywords(prompt_id); + CREATE INDEX IF NOT EXISTS idx_promptkeywords_keyword_id ON PromptKeywords(keyword_id); + ''') + +# FIXME - dirty hack that should be removed later... +# Migration function to add the 'author' column to the Prompts table +def add_author_column_to_prompts(): + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + # Check if 'author' column already exists + cursor.execute("PRAGMA table_info(Prompts)") + columns = [col[1] for col in cursor.fetchall()] + + if 'author' not in columns: + # Add the 'author' column + cursor.execute('ALTER TABLE Prompts ADD COLUMN author TEXT') + print("Author column added to Prompts table.") + else: + print("Author column already exists in Prompts table.") + +add_author_column_to_prompts() + +def normalize_keyword(keyword): + return re.sub(r'\s+', ' ', keyword.strip().lower()) + + +# FIXME - update calls to this function to use the new args +def add_prompt(name, author, details, system=None, user=None, keywords=None): + logging.debug(f"add_prompt: Adding prompt with name: {name}, author: {author}, system: {system}, user: {user}, keywords: {keywords}") + if not name: + logging.error("add_prompt: A name is required.") + return "A name is required." + + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO Prompts (name, author, details, system, user) + VALUES (?, ?, ?, ?, ?) + ''', (name, author, details, system, user)) + prompt_id = cursor.lastrowid + + if keywords: + normalized_keywords = [normalize_keyword(k) for k in keywords if k.strip()] + for keyword in set(normalized_keywords): # Use set to remove duplicates + cursor.execute(''' + INSERT OR IGNORE INTO Keywords (keyword) VALUES (?) 
+ ''', (keyword,)) + cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) + keyword_id = cursor.fetchone()[0] + cursor.execute(''' + INSERT OR IGNORE INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?) + ''', (prompt_id, keyword_id)) + return "Prompt added successfully." + except sqlite3.IntegrityError: + return "Prompt with this name already exists." + except sqlite3.Error as e: + return f"Database error: {e}" + + +def fetch_prompt_details(name): + logging.debug(f"fetch_prompt_details: Fetching details for prompt: {name}") + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT p.name, p.author, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords + FROM Prompts p + LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id + LEFT JOIN Keywords k ON pk.keyword_id = k.id + WHERE p.name = ? + GROUP BY p.id + ''', (name,)) + return cursor.fetchone() + + +def list_prompts(page=1, per_page=10): + logging.debug(f"list_prompts: Listing prompts for page {page} with {per_page} prompts per page.") + offset = (page - 1) * per_page + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute('SELECT name FROM Prompts LIMIT ? OFFSET ?', (per_page, offset)) + prompts = [row[0] for row in cursor.fetchall()] + + # Get total count of prompts + cursor.execute('SELECT COUNT(*) FROM Prompts') + total_count = cursor.fetchone()[0] + + total_pages = (total_count + per_page - 1) // per_page + return prompts, total_pages, page + + +def insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords=None): + return add_prompt(title, author, description, system_prompt, user_prompt, keywords) + + +def get_prompt_db_connection(): + prompt_db_path = get_database_path('prompts.db') + return sqlite3.connect(prompt_db_path) + + +def search_prompts(query): + logging.debug(f"search_prompts: Searching prompts with query: {query}") + try: + with get_prompt_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT p.name, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords + FROM Prompts p + LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id + LEFT JOIN Keywords k ON pk.keyword_id = k.id + WHERE p.name LIKE ? OR p.details LIKE ? OR p.system LIKE ? OR p.user LIKE ? OR k.keyword LIKE ? + GROUP BY p.id + ORDER BY p.name + """, (f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%')) + return cursor.fetchall() + except sqlite3.Error as e: + logging.error(f"Error searching prompts: {e}") + return [] + + +def search_prompts_by_keyword(keyword, page=1, per_page=10): + logging.debug(f"search_prompts_by_keyword: Searching prompts by keyword: {keyword}") + normalized_keyword = normalize_keyword(keyword) + offset = (page - 1) * per_page + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT DISTINCT p.name + FROM Prompts p + JOIN PromptKeywords pk ON p.id = pk.prompt_id + JOIN Keywords k ON pk.keyword_id = k.id + WHERE k.keyword LIKE ? + LIMIT ? OFFSET ? + ''', ('%' + normalized_keyword + '%', per_page, offset)) + prompts = [row[0] for row in cursor.fetchall()] + + # Get total count of matching prompts + cursor.execute(''' + SELECT COUNT(DISTINCT p.id) + FROM Prompts p + JOIN PromptKeywords pk ON p.id = pk.prompt_id + JOIN Keywords k ON pk.keyword_id = k.id + WHERE k.keyword LIKE ? 
+ ''', ('%' + normalized_keyword + '%',)) + total_count = cursor.fetchone()[0] + + total_pages = (total_count + per_page - 1) // per_page + return prompts, total_pages, page + + +def update_prompt_keywords(prompt_name, new_keywords): + logging.debug(f"update_prompt_keywords: Updating keywords for prompt: {prompt_name}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + cursor.execute('SELECT id FROM Prompts WHERE name = ?', (prompt_name,)) + prompt_id = cursor.fetchone() + if not prompt_id: + return "Prompt not found." + prompt_id = prompt_id[0] + + cursor.execute('DELETE FROM PromptKeywords WHERE prompt_id = ?', (prompt_id,)) + + normalized_keywords = [normalize_keyword(k) for k in new_keywords if k.strip()] + for keyword in set(normalized_keywords): # Use set to remove duplicates + cursor.execute('INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)', (keyword,)) + cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) + keyword_id = cursor.fetchone()[0] + cursor.execute('INSERT INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)', + (prompt_id, keyword_id)) + + # Remove unused keywords + cursor.execute(''' + DELETE FROM Keywords + WHERE id NOT IN (SELECT DISTINCT keyword_id FROM PromptKeywords) + ''') + return "Keywords updated successfully." + except sqlite3.Error as e: + return f"Database error: {e}" + + +def add_or_update_prompt(title, author, description, system_prompt, user_prompt, keywords=None): + logging.debug(f"add_or_update_prompt: Adding or updating prompt: {title}") + if not title: + return "Error: Title is required." + + existing_prompt = fetch_prompt_details(title) + if existing_prompt: + # Update existing prompt + result = update_prompt_in_db(title, author, description, system_prompt, user_prompt) + if "successfully" in result: + # Update keywords if the prompt update was successful + keyword_result = update_prompt_keywords(title, keywords or []) + result += f" {keyword_result}" + else: + # Insert new prompt + result = insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords) + + return result + + +def load_prompt_details(selected_prompt): + logging.debug(f"load_prompt_details: Loading prompt details for {selected_prompt}") + if selected_prompt: + details = fetch_prompt_details(selected_prompt) + if details: + return details[0], details[1], details[2], details[3], details[4], details[5] + return "", "", "", "", "", "" + + +def update_prompt_in_db(title, author, description, system_prompt, user_prompt): + logging.debug(f"update_prompt_in_db: Updating prompt: {title}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE Prompts SET author = ?, details = ?, system = ?, user = ? WHERE name = ?", + (author, description, system_prompt, user_prompt, title) + ) + if cursor.rowcount == 0: + return "No prompt found with the given title." + return "Prompt updated successfully!" 
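# --- Illustrative caller-side sketch (assumes the prompt helpers defined above in
# --- Prompts_DB.py; the title, author, and keyword values are made up) ---
# Callers would import these from App_Function_Libraries.DB.Prompts_DB.
status = add_or_update_prompt(
    title="Summarize transcript",
    author="example-author",
    description="Condense a transcript into bullet points.",
    system_prompt="You are a concise summarizer.",
    user_prompt="Summarize the following text:",
    keywords=["summarization", "transcripts"],
)                                       # returns a human-readable status string
name, author, details, system, user, keywords = load_prompt_details("Summarize transcript")
matches = search_prompts("summar")      # LIKE search over name/details/system/user/keywords
prompt_names, total_pages, page = list_prompts(page=1, per_page=10)
# --- End illustrative sketch ---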
+ except sqlite3.Error as e: + return f"Error updating prompt: {e}" + + +def delete_prompt(prompt_id): + logging.debug(f"delete_prompt: Deleting prompt with ID: {prompt_id}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + # Delete associated keywords + cursor.execute("DELETE FROM PromptKeywords WHERE prompt_id = ?", (prompt_id,)) + + # Delete the prompt + cursor.execute("DELETE FROM Prompts WHERE id = ?", (prompt_id,)) + + if cursor.rowcount == 0: + return f"No prompt found with ID {prompt_id}" + else: + conn.commit() + return f"Prompt with ID {prompt_id} has been successfully deleted" + except sqlite3.Error as e: + return f"An error occurred: {e}" + + +def delete_prompt_keyword(keyword: str) -> str: + """ + Delete a keyword and its associations from the prompts database. + + Args: + keyword (str): The keyword to delete + + Returns: + str: Success/failure message + """ + logging.debug(f"delete_prompt_keyword: Deleting keyword: {keyword}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + # First normalize the keyword + normalized_keyword = normalize_keyword(keyword) + + # Get the keyword ID + cursor.execute("SELECT id FROM Keywords WHERE keyword = ?", (normalized_keyword,)) + result = cursor.fetchone() + + if not result: + return f"Keyword '{keyword}' not found." + + keyword_id = result[0] + + # Delete keyword associations from PromptKeywords + cursor.execute("DELETE FROM PromptKeywords WHERE keyword_id = ?", (keyword_id,)) + + # Delete the keyword itself + cursor.execute("DELETE FROM Keywords WHERE id = ?", (keyword_id,)) + + # Get the number of affected prompts + affected_prompts = cursor.rowcount + + conn.commit() + + logging.info(f"Keyword '{keyword}' deleted successfully") + return f"Successfully deleted keyword '{keyword}' and removed it from {affected_prompts} prompts." + + except sqlite3.Error as e: + error_msg = f"Database error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_prompt_keywords_to_csv() -> Tuple[str, str]: + """ + Export all prompt keywords to a CSV file with associated metadata. 
+ + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + import tempfile + import os + from datetime import datetime + + logging.debug("export_prompt_keywords_to_csv: Starting export") + try: + # Create a temporary file with a specific name in the system's temp directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, f'prompt_keywords_export_{timestamp}.csv') + + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + # Get keywords with related prompt information + query = ''' + SELECT + k.keyword, + GROUP_CONCAT(p.name, ' | ') as prompt_names, + COUNT(DISTINCT p.id) as num_prompts, + GROUP_CONCAT(DISTINCT p.author, ' | ') as authors + FROM Keywords k + LEFT JOIN PromptKeywords pk ON k.id = pk.keyword_id + LEFT JOIN Prompts p ON pk.prompt_id = p.id + GROUP BY k.id, k.keyword + ORDER BY k.keyword + ''' + + cursor.execute(query) + results = cursor.fetchall() + + # Write to CSV + with open(file_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([ + 'Keyword', + 'Associated Prompts', + 'Number of Prompts', + 'Authors' + ]) + + for row in results: + writer.writerow([ + row[0], # keyword + row[1] if row[1] else '', # prompt_names (may be None) + row[2], # num_prompts + row[3] if row[3] else '' # authors (may be None) + ]) + + status_msg = f"Successfully exported {len(results)} prompt keywords to CSV." + logging.info(status_msg) + + return status_msg, file_path + + except sqlite3.Error as e: + error_msg = f"Database error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + + +def view_prompt_keywords() -> str: + """ + View all keywords currently in the prompts database. + + Returns: + str: Markdown formatted string of all keywords + """ + logging.debug("view_prompt_keywords: Retrieving all keywords") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT k.keyword, COUNT(DISTINCT pk.prompt_id) as prompt_count + FROM Keywords k + LEFT JOIN PromptKeywords pk ON k.id = pk.keyword_id + GROUP BY k.id, k.keyword + ORDER BY k.keyword + """) + + keywords = cursor.fetchall() + if keywords: + keyword_list = [f"- {k[0]} ({k[1]} prompts)" for k in keywords] + return "### Current Prompt Keywords:\n" + "\n".join(keyword_list) + return "No keywords found." + + except Exception as e: + error_msg = f"Error retrieving keywords: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_prompts( + export_format='csv', + filter_keywords=None, + include_system=True, + include_user=True, + include_details=True, + include_author=True, + include_keywords=True, + markdown_template=None +) -> Tuple[str, str]: + """ + Export prompts to CSV or Markdown with configurable options. 
+ + Args: + export_format (str): 'csv' or 'markdown' + filter_keywords (List[str], optional): Keywords to filter prompts by + include_system (bool): Include system prompts in export + include_user (bool): Include user prompts in export + include_details (bool): Include prompt details/descriptions + include_author (bool): Include author information + include_keywords (bool): Include associated keywords + markdown_template (str, optional): Template for markdown export + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + import tempfile + import os + import zipfile + from datetime import datetime + + try: + # Get prompts data + with get_prompt_db_connection() as conn: + cursor = conn.cursor() + + # Build query based on included fields + select_fields = ['p.name'] + if include_author: + select_fields.append('p.author') + if include_details: + select_fields.append('p.details') + if include_system: + select_fields.append('p.system') + if include_user: + select_fields.append('p.user') + + query = f""" + SELECT DISTINCT {', '.join(select_fields)} + FROM Prompts p + """ + + # Add keyword filtering if specified + if filter_keywords: + placeholders = ','.join(['?' for _ in filter_keywords]) + query += f""" + JOIN PromptKeywords pk ON p.id = pk.prompt_id + JOIN Keywords k ON pk.keyword_id = k.id + WHERE k.keyword IN ({placeholders}) + """ + + cursor.execute(query, filter_keywords if filter_keywords else ()) + prompts = cursor.fetchall() + + # Get keywords for each prompt if needed + if include_keywords: + prompt_keywords = {} + for prompt in prompts: + cursor.execute(""" + SELECT k.keyword + FROM Keywords k + JOIN PromptKeywords pk ON k.id = pk.keyword_id + JOIN Prompts p ON pk.prompt_id = p.id + WHERE p.name = ? + """, (prompt[0],)) + prompt_keywords[prompt[0]] = [row[0] for row in cursor.fetchall()] + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if export_format == 'csv': + # Export as CSV + temp_file = os.path.join(tempfile.gettempdir(), f'prompts_export_{timestamp}.csv') + with open(temp_file, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + + # Write header + header = ['Name'] + if include_author: + header.append('Author') + if include_details: + header.append('Details') + if include_system: + header.append('System Prompt') + if include_user: + header.append('User Prompt') + if include_keywords: + header.append('Keywords') + writer.writerow(header) + + # Write data + for prompt in prompts: + row = list(prompt) + if include_keywords: + row.append(', '.join(prompt_keywords.get(prompt[0], []))) + writer.writerow(row) + + return f"Successfully exported {len(prompts)} prompts to CSV.", temp_file + + else: + # Export as Markdown files in ZIP + temp_dir = tempfile.mkdtemp() + zip_path = os.path.join(tempfile.gettempdir(), f'prompts_export_{timestamp}.zip') + + # Define markdown templates + templates = { + "Basic Template": """# {title} +{author_section} +{details_section} +{system_section} +{user_section} +{keywords_section} +""", + "Detailed Template": """# {title} + +## Author +{author_section} + +## Description +{details_section} + +## System Prompt +{system_section} + +## User Prompt +{user_section} + +## Keywords +{keywords_section} +""" + } + + template = templates.get(markdown_template, markdown_template or templates["Basic Template"]) + + with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + for prompt in prompts: + # Create markdown content + md_content = template.format( + title=prompt[0], + 
author_section=f"Author: {prompt[1]}" if include_author else "", + details_section=prompt[2] if include_details else "", + system_section=prompt[3] if include_system else "", + user_section=prompt[4] if include_user else "", + keywords_section=', '.join(prompt_keywords.get(prompt[0], [])) if include_keywords else "" + ) + + # Create safe filename + safe_filename = re.sub(r'[^\w\-_\. ]', '_', prompt[0]) + md_path = os.path.join(temp_dir, f"{safe_filename}.md") + + # Write markdown file + with open(md_path, 'w', encoding='utf-8') as f: + f.write(md_content) + + # Add to ZIP + zipf.write(md_path, os.path.basename(md_path)) + + return f"Successfully exported {len(prompts)} prompts to Markdown files.", zip_path + + except Exception as e: + error_msg = f"Error exporting prompts: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + + +create_prompts_db() + +# +# End of Propmts_DB.py +####################################################################################################################### + diff --git a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py index f4df79894..35db2b3ba 100644 --- a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py +++ b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py @@ -4,39 +4,37 @@ # Imports import configparser import logging +import os import re import sqlite3 import uuid from contextlib import contextmanager from datetime import datetime - -from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path - +from pathlib import Path +from typing import List, Dict, Any, Tuple, Optional # # External Imports # (No external imports) # # Local Imports -# (No additional local imports) +from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_project_root + # ######################################################################################################################## # # Functions: -# Construct the path to the config file -config_path = get_project_relative_path('Config_Files/config.txt') - -# Read the config file -config = configparser.ConfigParser() -config.read(config_path) - -# Get the SQLite path from the config, or use the default if not specified -if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'): - rag_qa_db_path = config.get('Database', 'rag_qa_db_path') -else: - rag_qa_db_path = get_database_path('RAG_QA_Chat.db') - -print(f"RAG QA Chat Database path: {rag_qa_db_path}") +def get_rag_qa_db_path(): + config_path = os.path.join(get_project_root(), 'Config_Files', 'config.txt') + config = configparser.ConfigParser() + config.read(config_path) + if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'): + rag_qa_db_path = config.get('Database', 'rag_qa_db_path') + if not os.path.isabs(rag_qa_db_path): + rag_qa_db_path = get_project_relative_path(rag_qa_db_path) + return rag_qa_db_path + else: + raise ValueError("Database path not found in config file") # Set up logging logging.basicConfig(level=logging.INFO) @@ -58,7 +56,9 @@ conversation_id TEXT PRIMARY KEY, created_at DATETIME NOT NULL, last_updated DATETIME NOT NULL, - title TEXT NOT NULL + title TEXT NOT NULL, + media_id INTEGER, + rating INTEGER CHECK(rating BETWEEN 1 AND 3) ); -- Table for storing keywords @@ -122,19 +122,137 @@ CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_collection_id ON rag_qa_collection_keywords(collection_id); CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_keyword_id ON 
rag_qa_collection_keywords(keyword_id); --- Full-text search virtual table for chat content -CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5(conversation_id, timestamp, role, content); +-- Full-text search virtual tables +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5( + content, + content='rag_qa_chats', + content_rowid='id' +); + +-- FTS table for conversation metadata +CREATE VIRTUAL TABLE IF NOT EXISTS conversation_metadata_fts USING fts5( + title, + content='conversation_metadata', + content_rowid='rowid' +); + +-- FTS table for keywords +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_keywords_fts USING fts5( + keyword, + content='rag_qa_keywords', + content_rowid='id' +); + +-- FTS table for keyword collections +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_keyword_collections_fts USING fts5( + name, + content='rag_qa_keyword_collections', + content_rowid='id' +); + +-- FTS table for notes +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_notes_fts USING fts5( + title, + content, + content='rag_qa_notes', + content_rowid='id' +); +-- FTS table for notes (modified to include both title and content) +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_notes_fts USING fts5( + title, + content, + content='rag_qa_notes', + content_rowid='id' +); --- Trigger to keep the FTS table up to date +-- Triggers for maintaining FTS indexes +-- Triggers for rag_qa_chats CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ai AFTER INSERT ON rag_qa_chats BEGIN - INSERT INTO rag_qa_chats_fts(conversation_id, timestamp, role, content) VALUES (new.conversation_id, new.timestamp, new.role, new.content); + INSERT INTO rag_qa_chats_fts(rowid, content) + VALUES (new.id, new.content); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_chats_au AFTER UPDATE ON rag_qa_chats BEGIN + UPDATE rag_qa_chats_fts + SET content = new.content + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ad AFTER DELETE ON rag_qa_chats BEGIN + DELETE FROM rag_qa_chats_fts WHERE rowid = old.id; +END; + +-- Triggers for conversation_metadata +CREATE TRIGGER IF NOT EXISTS conversation_metadata_ai AFTER INSERT ON conversation_metadata BEGIN + INSERT INTO conversation_metadata_fts(rowid, title) + VALUES (new.rowid, new.title); +END; + +CREATE TRIGGER IF NOT EXISTS conversation_metadata_au AFTER UPDATE ON conversation_metadata BEGIN + UPDATE conversation_metadata_fts + SET title = new.title + WHERE rowid = old.rowid; +END; + +CREATE TRIGGER IF NOT EXISTS conversation_metadata_ad AFTER DELETE ON conversation_metadata BEGIN + DELETE FROM conversation_metadata_fts WHERE rowid = old.rowid; +END; + +-- Triggers for rag_qa_keywords +CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_ai AFTER INSERT ON rag_qa_keywords BEGIN + INSERT INTO rag_qa_keywords_fts(rowid, keyword) + VALUES (new.id, new.keyword); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_au AFTER UPDATE ON rag_qa_keywords BEGIN + UPDATE rag_qa_keywords_fts + SET keyword = new.keyword + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_ad AFTER DELETE ON rag_qa_keywords BEGIN + DELETE FROM rag_qa_keywords_fts WHERE rowid = old.id; +END; + +-- Triggers for rag_qa_keyword_collections +CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_ai AFTER INSERT ON rag_qa_keyword_collections BEGIN + INSERT INTO rag_qa_keyword_collections_fts(rowid, name) + VALUES (new.id, new.name); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_au AFTER UPDATE ON rag_qa_keyword_collections BEGIN + UPDATE rag_qa_keyword_collections_fts + SET name = 
new.name + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_ad AFTER DELETE ON rag_qa_keyword_collections BEGIN + DELETE FROM rag_qa_keyword_collections_fts WHERE rowid = old.id; +END; + +-- Triggers for rag_qa_notes +CREATE TRIGGER IF NOT EXISTS rag_qa_notes_ai AFTER INSERT ON rag_qa_notes BEGIN + INSERT INTO rag_qa_notes_fts(rowid, title, content) + VALUES (new.id, new.title, new.content); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_notes_au AFTER UPDATE ON rag_qa_notes BEGIN + UPDATE rag_qa_notes_fts + SET title = new.title, + content = new.content + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_notes_ad AFTER DELETE ON rag_qa_notes BEGIN + DELETE FROM rag_qa_notes_fts WHERE rowid = old.id; END; ''' # Database connection management @contextmanager def get_db_connection(): - conn = sqlite3.connect(rag_qa_db_path) + db_path = get_rag_qa_db_path() + conn = sqlite3.connect(db_path) try: yield conn finally: @@ -168,10 +286,43 @@ def execute_query(query, params=None, conn=None): conn.commit() return cursor.fetchall() + def create_tables(): + """Create database tables and initialize FTS indexes.""" with get_db_connection() as conn: - conn.executescript(SCHEMA_SQL) - logger.info("All RAG QA Chat tables created successfully") + cursor = conn.cursor() + # Execute the SCHEMA_SQL to create tables and triggers + cursor.executescript(SCHEMA_SQL) + + # Check and populate all FTS tables + fts_tables = [ + ('rag_qa_notes_fts', 'rag_qa_notes', ['title', 'content']), + ('rag_qa_chats_fts', 'rag_qa_chats', ['content']), + ('conversation_metadata_fts', 'conversation_metadata', ['title']), + ('rag_qa_keywords_fts', 'rag_qa_keywords', ['keyword']), + ('rag_qa_keyword_collections_fts', 'rag_qa_keyword_collections', ['name']) + ] + + for fts_table, source_table, columns in fts_tables: + # Check if FTS table needs population + cursor.execute(f"SELECT COUNT(*) FROM {fts_table}") + fts_count = cursor.fetchone()[0] + cursor.execute(f"SELECT COUNT(*) FROM {source_table}") + source_count = cursor.fetchone()[0] + + if fts_count != source_count: + # Repopulate FTS table + logger.info(f"Repopulating {fts_table}") + cursor.execute(f"DELETE FROM {fts_table}") + columns_str = ', '.join(columns) + source_columns = ', '.join([f"id" if source_table != 'conversation_metadata' else "rowid"] + columns) + cursor.execute(f""" + INSERT INTO {fts_table}(rowid, {columns_str}) + SELECT {source_columns} FROM {source_table} + """) + + logger.info("All RAG QA Chat tables and triggers created successfully") + # Initialize the database create_tables() @@ -197,6 +348,7 @@ def validate_keyword(keyword): raise ValueError("Keyword contains invalid characters") return keyword.strip() + def validate_collection_name(name): if not isinstance(name, str): raise ValueError("Collection name must be a string") @@ -208,6 +360,7 @@ def validate_collection_name(name): raise ValueError("Collection name contains invalid characters") return name.strip() + # Core functions def add_keyword(keyword, conn=None): try: @@ -222,6 +375,7 @@ def add_keyword(keyword, conn=None): logger.error(f"Error adding keyword '{keyword}': {e}") raise + def create_keyword_collection(name, parent_id=None): try: validated_name = validate_collection_name(name) @@ -235,6 +389,7 @@ def create_keyword_collection(name, parent_id=None): logger.error(f"Error creating keyword collection '{name}': {e}") raise + def add_keyword_to_collection(collection_name, keyword): try: validated_collection_name = 
validate_collection_name(collection_name) @@ -259,6 +414,7 @@ def add_keyword_to_collection(collection_name, keyword): logger.error(f"Error adding keyword '{keyword}' to collection '{collection_name}': {e}") raise + def add_keywords_to_conversation(conversation_id, keywords): if not isinstance(keywords, (list, tuple)): raise ValueError("Keywords must be a list or tuple") @@ -282,6 +438,23 @@ def add_keywords_to_conversation(conversation_id, keywords): logger.error(f"Error adding keywords to conversation '{conversation_id}': {e}") raise + +def view_rag_keywords(): + try: + rag_db_path = get_rag_qa_db_path() + with sqlite3.connect(rag_db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT keyword FROM rag_qa_keywords ORDER BY keyword") + keywords = cursor.fetchall() + if keywords: + keyword_list = [k[0] for k in keywords] + return "### Current RAG QA Keywords:\n" + "\n".join( + [f"- {k}" for k in keyword_list]) + return "No keywords found." + except Exception as e: + return f"Error retrieving keywords: {str(e)}" + + def get_keywords_for_conversation(conversation_id): try: query = ''' @@ -298,6 +471,7 @@ def get_keywords_for_conversation(conversation_id): logger.error(f"Error getting keywords for conversation '{conversation_id}': {e}") raise + def get_keywords_for_collection(collection_name): try: query = ''' @@ -315,6 +489,116 @@ def get_keywords_for_collection(collection_name): logger.error(f"Error getting keywords for collection '{collection_name}': {e}") raise + +def delete_rag_keyword(keyword: str) -> str: + """ + Delete a keyword from the RAG QA database and all its associations. + + Args: + keyword (str): The keyword to delete + + Returns: + str: Success/failure message + """ + try: + # Validate the keyword + validated_keyword = validate_keyword(keyword) + + with transaction() as conn: + # First, get the keyword ID + cursor = conn.cursor() + cursor.execute("SELECT id FROM rag_qa_keywords WHERE keyword = ?", (validated_keyword,)) + result = cursor.fetchone() + + if not result: + return f"Keyword '{validated_keyword}' not found." + + keyword_id = result[0] + + # Delete from all associated tables + cursor.execute("DELETE FROM rag_qa_conversation_keywords WHERE keyword_id = ?", (keyword_id,)) + cursor.execute("DELETE FROM rag_qa_collection_keywords WHERE keyword_id = ?", (keyword_id,)) + cursor.execute("DELETE FROM rag_qa_note_keywords WHERE keyword_id = ?", (keyword_id,)) + + # Finally, delete the keyword itself + cursor.execute("DELETE FROM rag_qa_keywords WHERE id = ?", (keyword_id,)) + + logger.info(f"Keyword '{validated_keyword}' deleted successfully") + return f"Successfully deleted keyword '{validated_keyword}' and all its associations." + + except ValueError as e: + error_msg = f"Invalid keyword: {str(e)}" + logger.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logger.error(error_msg) + return error_msg + + +def export_rag_keywords_to_csv() -> Tuple[str, str]: + """ + Export all RAG QA keywords to a CSV file. 
+ + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + from tempfile import NamedTemporaryFile + from datetime import datetime + + try: + # Create a temporary CSV file + temp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', newline='') + + with transaction() as conn: + cursor = conn.cursor() + + # Get all keywords and their associations + query = """ + SELECT + k.keyword, + GROUP_CONCAT(DISTINCT c.name) as collections, + COUNT(DISTINCT ck.conversation_id) as num_conversations, + COUNT(DISTINCT nk.note_id) as num_notes + FROM rag_qa_keywords k + LEFT JOIN rag_qa_collection_keywords col_k ON k.id = col_k.keyword_id + LEFT JOIN rag_qa_keyword_collections c ON col_k.collection_id = c.id + LEFT JOIN rag_qa_conversation_keywords ck ON k.id = ck.keyword_id + LEFT JOIN rag_qa_note_keywords nk ON k.id = nk.keyword_id + GROUP BY k.id, k.keyword + ORDER BY k.keyword + """ + + cursor.execute(query) + results = cursor.fetchall() + + # Write to CSV + writer = csv.writer(temp_file) + writer.writerow(['Keyword', 'Collections', 'Number of Conversations', 'Number of Notes']) + + for row in results: + writer.writerow([ + row[0], # keyword + row[1] if row[1] else '', # collections (may be None) + row[2], # num_conversations + row[3] # num_notes + ]) + + temp_file.close() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + status_msg = f"Successfully exported {len(results)} keywords to CSV." + logger.info(status_msg) + + return status_msg, temp_file.name + + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logger.error(error_msg) + return error_msg, "" + + # # End of Keyword-related functions ################################################### @@ -339,6 +623,7 @@ def save_notes(conversation_id, title, content): logger.error(f"Error saving notes for conversation '{conversation_id}': {e}") raise + def update_note(note_id, title, content): try: query = "UPDATE rag_qa_notes SET title = ?, content = ?, timestamp = ? WHERE id = ?" @@ -349,6 +634,121 @@ def update_note(note_id, title, content): logger.error(f"Error updating note ID '{note_id}': {e}") raise + +def search_notes_titles(search_term: str, page: int = 1, results_per_page: int = 20, connection=None) -> Tuple[ + List[Tuple], int, int]: + """ + Search note titles using full-text search. Returns all notes if search_term is empty. + + Args: + search_term (str): The search term for note titles. If empty, returns all notes. + page (int, optional): Page number for pagination. Defaults to 1. + results_per_page (int, optional): Number of results per page. Defaults to 20. + connection (sqlite3.Connection, optional): Database connection. Uses new connection if not provided. 
+ + Returns: + Tuple[List[Tuple], int, int]: Tuple containing: + - List of tuples: (note_id, title, content, timestamp, conversation_id) + - Total number of pages + - Total count of matching records + + Raises: + ValueError: If page number is less than 1 + sqlite3.Error: If there's a database error + """ + if page < 1: + raise ValueError("Page number must be 1 or greater.") + + offset = (page - 1) * results_per_page + + def execute_search(conn): + cursor = conn.cursor() + + # Debug: Show table contents + cursor.execute("SELECT title FROM rag_qa_notes") + main_titles = cursor.fetchall() + logger.debug(f"Main table titles: {main_titles}") + + cursor.execute("SELECT title FROM rag_qa_notes_fts") + fts_titles = cursor.fetchall() + logger.debug(f"FTS table titles: {fts_titles}") + + if not search_term.strip(): + # Query for all notes + cursor.execute( + """ + SELECT COUNT(*) + FROM rag_qa_notes + """ + ) + total_count = cursor.fetchone()[0] + + cursor.execute( + """ + SELECT id, title, content, timestamp, conversation_id + FROM rag_qa_notes + ORDER BY timestamp DESC + LIMIT ? OFFSET ? + """, + (results_per_page, offset) + ) + results = cursor.fetchall() + else: + # Search query + search_term_clean = search_term.strip().lower() + + # Test direct FTS search + cursor.execute( + """ + SELECT COUNT(*) + FROM rag_qa_notes n + JOIN rag_qa_notes_fts fts ON n.id = fts.rowid + WHERE fts.title MATCH ? + """, + (search_term_clean,) + ) + total_count = cursor.fetchone()[0] + + cursor.execute( + """ + SELECT + n.id, + n.title, + n.content, + n.timestamp, + n.conversation_id + FROM rag_qa_notes n + JOIN rag_qa_notes_fts fts ON n.id = fts.rowid + WHERE fts.title MATCH ? + ORDER BY rank + LIMIT ? OFFSET ? + """, + (search_term_clean, results_per_page, offset) + ) + results = cursor.fetchall() + + logger.debug(f"Search term: {search_term_clean}") + logger.debug(f"Results: {results}") + + total_pages = max(1, (total_count + results_per_page - 1) // results_per_page) + logger.info(f"Found {total_count} matching notes for search term '{search_term}'") + + return results, total_pages, total_count + + try: + if connection: + return execute_search(connection) + else: + with get_db_connection() as conn: + return execute_search(conn) + + except sqlite3.Error as e: + logger.error(f"Database error in search_notes_titles: {str(e)}") + logger.error(f"Search term: {search_term}") + raise sqlite3.Error(f"Error searching notes: {str(e)}") + + + def get_notes(conversation_id): """Retrieve notes for a given conversation.""" try: @@ -361,6 +761,7 @@ def get_notes(conversation_id): logger.error(f"Error getting notes for conversation '{conversation_id}': {e}") raise + def get_note_by_id(note_id): try: query = "SELECT id, title, content FROM rag_qa_notes WHERE id = ?" 
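# --- Illustrative caller-side sketch (assumes the note helpers defined above in
# --- RAG_QA_Chat_DB.py; the titles and search term are made-up values) ---
# Callers would import these from App_Function_Libraries.DB.RAG_QA_Chat_DB.
conversation_id = start_new_conversation(title="Notes demo")
save_notes(conversation_id, title="Key takeaways", content="FTS5 keeps note titles searchable.")

# FTS-backed title search; an empty search term pages through all notes, newest first.
results, total_pages, total_count = search_notes_titles("takeaways", page=1, results_per_page=20)
for note_id, title, content, timestamp, conv_id in results:
    print(note_id, title, timestamp)
# --- End illustrative sketch ---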
@@ -370,9 +771,21 @@ def get_note_by_id(note_id): logger.error(f"Error getting note by ID '{note_id}': {e}") raise + def get_notes_by_keywords(keywords, page=1, page_size=20): try: - placeholders = ','.join(['?'] * len(keywords)) + # Handle empty or invalid keywords + if not keywords or not isinstance(keywords, (list, tuple)) or len(keywords) == 0: + return [], 0, 0 + + # Convert all keywords to strings and strip them + clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()] + + # If no valid keywords after cleaning, return empty result + if not clean_keywords: + return [], 0, 0 + + placeholders = ','.join(['?'] * len(clean_keywords)) query = f''' SELECT n.id, n.title, n.content, n.timestamp FROM rag_qa_notes n @@ -381,14 +794,15 @@ def get_notes_by_keywords(keywords, page=1, page_size=20): WHERE k.keyword IN ({placeholders}) ORDER BY n.timestamp DESC ''' - results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size) - logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(keywords)} (page {page} of {total_pages})") + results, total_pages, total_count = get_paginated_results(query, tuple(clean_keywords), page, page_size) + logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(clean_keywords)} (page {page} of {total_pages})") notes = [(row[0], row[1], row[2], row[3]) for row in results] return notes, total_pages, total_count except Exception as e: logger.error(f"Error getting notes by keywords: {e}") raise + def get_notes_by_keyword_collection(collection_name, page=1, page_size=20): try: query = ''' @@ -501,9 +915,10 @@ def delete_note(note_id): # # Chat-related functions -def save_message(conversation_id, role, content): +def save_message(conversation_id, role, content, timestamp=None): try: - timestamp = datetime.now().isoformat() + if timestamp is None: + timestamp = datetime.now().isoformat() query = "INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content) VALUES (?, ?, ?, ?)" execute_query(query, (conversation_id, timestamp, role, content)) @@ -516,29 +931,103 @@ def save_message(conversation_id, role, content): logger.error(f"Error saving message for conversation '{conversation_id}': {e}") raise -def start_new_conversation(title="Untitled Conversation"): + +def start_new_conversation(title="Untitled Conversation", media_id=None): try: conversation_id = str(uuid.uuid4()) - query = "INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title) VALUES (?, ?, ?, ?)" + query = """ + INSERT INTO conversation_metadata ( + conversation_id, created_at, last_updated, title, media_id, rating + ) VALUES (?, ?, ?, ?, ?, ?) 
+ """ now = datetime.now().isoformat() - execute_query(query, (conversation_id, now, now, title)) - logger.info(f"New conversation '{conversation_id}' started with title '{title}'") + # Set initial rating to NULL + execute_query(query, (conversation_id, now, now, title, media_id, None)) + logger.info(f"New conversation '{conversation_id}' started with title '{title}' and media_id '{media_id}'") return conversation_id except Exception as e: logger.error(f"Error starting new conversation: {e}") raise + def get_all_conversations(page=1, page_size=20): try: - query = "SELECT conversation_id, title FROM conversation_metadata ORDER BY last_updated DESC" - results, total_pages, total_count = get_paginated_results(query, page=page, page_size=page_size) - conversations = [(row[0], row[1]) for row in results] - logger.info(f"Retrieved {len(conversations)} conversations (page {page} of {total_pages})") - return conversations, total_pages, total_count + query = """ + SELECT conversation_id, title, media_id, rating + FROM conversation_metadata + ORDER BY last_updated DESC + LIMIT ? OFFSET ? + """ + + count_query = "SELECT COUNT(*) FROM conversation_metadata" + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + + # Get total count + cursor.execute(count_query) + total_count = cursor.fetchone()[0] + total_pages = (total_count + page_size - 1) // page_size + + # Get page of results + offset = (page - 1) * page_size + cursor.execute(query, (page_size, offset)) + results = cursor.fetchall() + + conversations = [{ + 'conversation_id': row[0], + 'title': row[1], + 'media_id': row[2], + 'rating': row[3] # Include rating + } for row in results] + return conversations, total_pages, total_count except Exception as e: - logger.error(f"Error getting conversations: {e}") + logging.error(f"Error getting conversations: {e}") raise + +def get_all_notes(page=1, page_size=20): + try: + query = """ + SELECT n.id, n.conversation_id, n.title, n.content, n.timestamp, + cm.title as conversation_title, cm.media_id + FROM rag_qa_notes n + LEFT JOIN conversation_metadata cm ON n.conversation_id = cm.conversation_id + ORDER BY n.timestamp DESC + LIMIT ? OFFSET ? 
+ """ + + count_query = "SELECT COUNT(*) FROM rag_qa_notes" + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + + # Get total count + cursor.execute(count_query) + total_count = cursor.fetchone()[0] + total_pages = (total_count + page_size - 1) // page_size + + # Get page of results + offset = (page - 1) * page_size + cursor.execute(query, (page_size, offset)) + results = cursor.fetchall() + + notes = [{ + 'id': row[0], + 'conversation_id': row[1], + 'title': row[2], + 'content': row[3], + 'timestamp': row[4], + 'conversation_title': row[5], + 'media_id': row[6] + } for row in results] + + return notes, total_pages, total_count + except Exception as e: + logging.error(f"Error getting notes: {e}") + raise + + # Pagination helper function def get_paginated_results(query, params=None, page=1, page_size=20): try: @@ -564,6 +1053,7 @@ def get_paginated_results(query, params=None, page=1, page_size=20): logger.error(f"Error retrieving paginated results: {e}") raise + def get_all_collections(page=1, page_size=20): try: query = "SELECT name FROM rag_qa_keyword_collections" @@ -575,24 +1065,79 @@ def get_all_collections(page=1, page_size=20): logger.error(f"Error getting collections: {e}") raise -def search_conversations_by_keywords(keywords, page=1, page_size=20): + +def search_conversations_by_keywords(keywords=None, title_query=None, content_query=None, page=1, page_size=20): try: - placeholders = ','.join(['?' for _ in keywords]) - query = f''' - SELECT DISTINCT cm.conversation_id, cm.title + # Base query starts with conversation metadata + query = """ + SELECT DISTINCT cm.conversation_id, cm.title, cm.last_updated FROM conversation_metadata cm - JOIN rag_qa_conversation_keywords ck ON cm.conversation_id = ck.conversation_id - JOIN rag_qa_keywords k ON ck.keyword_id = k.id - WHERE k.keyword IN ({placeholders}) - ''' - results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size) - logger.info( - f"Found {total_count} conversations matching keywords: {', '.join(keywords)} (page {page} of {total_pages})") - return results, total_pages, total_count + WHERE 1=1 + """ + params = [] + + # Add content search if provided + if content_query and isinstance(content_query, str) and content_query.strip(): + query += """ + AND EXISTS ( + SELECT 1 FROM rag_qa_chats_fts + WHERE rag_qa_chats_fts.content MATCH ? + AND rag_qa_chats_fts.rowid IN ( + SELECT id FROM rag_qa_chats + WHERE conversation_id = cm.conversation_id + ) + ) + """ + params.append(content_query.strip()) + + # Add title search if provided + if title_query and isinstance(title_query, str) and title_query.strip(): + query += """ + AND EXISTS ( + SELECT 1 FROM conversation_metadata_fts + WHERE conversation_metadata_fts.title MATCH ? + AND conversation_metadata_fts.rowid = cm.rowid + ) + """ + params.append(title_query.strip()) + + # Add keyword search if provided + if keywords and isinstance(keywords, (list, tuple)) and len(keywords) > 0: + # Convert all keywords to strings and strip them + clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()] + if clean_keywords: # Only add to query if we have valid keywords + placeholders = ','.join(['?' 
for _ in clean_keywords]) + query += f""" + AND EXISTS ( + SELECT 1 FROM rag_qa_conversation_keywords ck + JOIN rag_qa_keywords k ON ck.keyword_id = k.id + WHERE ck.conversation_id = cm.conversation_id + AND k.keyword IN ({placeholders}) + ) + """ + params.extend(clean_keywords) + + # Add ordering + query += " ORDER BY cm.last_updated DESC" + + results, total_pages, total_count = get_paginated_results(query, tuple(params), page, page_size) + + conversations = [ + { + 'conversation_id': row[0], + 'title': row[1], + 'last_updated': row[2] + } + for row in results + ] + + return conversations, total_pages, total_count + except Exception as e: - logger.error(f"Error searching conversations by keywords {keywords}: {e}") + logger.error(f"Error searching conversations: {e}") raise + def load_chat_history(conversation_id, page=1, page_size=50): try: query = "SELECT role, content FROM rag_qa_chats WHERE conversation_id = ? ORDER BY timestamp" @@ -604,6 +1149,7 @@ def load_chat_history(conversation_id, page=1, page_size=50): logger.error(f"Error loading chat history for conversation '{conversation_id}': {e}") raise + def update_conversation_title(conversation_id, new_title): """Update the title of a conversation.""" try: @@ -614,6 +1160,7 @@ def update_conversation_title(conversation_id, new_title): logger.error(f"Error updating conversation title: {e}") raise + def delete_messages_in_conversation(conversation_id): """Helper function to delete all messages in a conversation.""" try: @@ -623,6 +1170,7 @@ def delete_messages_in_conversation(conversation_id): logging.error(f"Error deleting messages in conversation '{conversation_id}': {e}") raise + def get_conversation_title(conversation_id): """Helper function to get the conversation title.""" query = "SELECT title FROM conversation_metadata WHERE conversation_id = ?" @@ -632,6 +1180,39 @@ def get_conversation_title(conversation_id): else: return "Untitled Conversation" + +def get_conversation_text(conversation_id): + try: + query = """ + SELECT role, content + FROM rag_qa_chats + WHERE conversation_id = ? + ORDER BY timestamp ASC + """ + + messages = [] + # Use the connection as a context manager + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, (conversation_id,)) + messages = cursor.fetchall() + + return "\n\n".join([f"{msg[0]}: {msg[1]}" for msg in messages]) + except Exception as e: + logger.error(f"Error getting conversation text: {e}") + raise + + +def get_conversation_details(conversation_id): + query = "SELECT title, media_id, rating FROM conversation_metadata WHERE conversation_id = ?" + result = execute_query(query, (conversation_id,)) + if result: + return {'title': result[0][0], 'media_id': result[0][1], 'rating': result[0][2]} + else: + return {'title': "Untitled Conversation", 'media_id': None, 'rating': None} + + def delete_conversation(conversation_id): """Delete a conversation and its associated messages and notes.""" try: @@ -651,11 +1232,203 @@ def delete_conversation(conversation_id): logger.error(f"Error deleting conversation '{conversation_id}': {e}") raise +def set_conversation_rating(conversation_id, rating): + """Set the rating for a conversation.""" + # Validate rating + if rating not in [1, 2, 3]: + raise ValueError('Rating must be an integer between 1 and 3.') + try: + query = "UPDATE conversation_metadata SET rating = ? WHERE conversation_id = ?" 
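# --- Illustrative caller-side sketch (assumes the conversation helpers defined
# --- above in RAG_QA_Chat_DB.py; the titles and messages are made-up values) ---
# Callers would import these from App_Function_Libraries.DB.RAG_QA_Chat_DB.
conversation_id = start_new_conversation(title="Podcast Q&A", media_id=None)   # media_id is optional
save_message(conversation_id, role="user", content="What was the main argument?")
save_message(conversation_id, role="assistant", content="The speaker argued that caching helps.")

transcript = get_conversation_text(conversation_id)     # "user: ...\n\nassistant: ..."
details = get_conversation_details(conversation_id)     # {'title': ..., 'media_id': ..., 'rating': ...}
conversations, total_pages, total_count = get_all_conversations(page=1, page_size=20)
# --- End illustrative sketch ---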
+ execute_query(query, (rating, conversation_id)) + logger.info(f"Rating for conversation '{conversation_id}' set to {rating}") + except Exception as e: + logger.error(f"Error setting rating for conversation '{conversation_id}': {e}") + raise + +def get_conversation_rating(conversation_id): + """Get the rating of a conversation.""" + try: + query = "SELECT rating FROM conversation_metadata WHERE conversation_id = ?" + result = execute_query(query, (conversation_id,)) + if result: + rating = result[0][0] + logger.info(f"Rating for conversation '{conversation_id}' is {rating}") + return rating + else: + logger.warning(f"Conversation '{conversation_id}' not found.") + return None + except Exception as e: + logger.error(f"Error getting rating for conversation '{conversation_id}': {e}") + raise + + +def get_conversation_name(conversation_id: str) -> str: + """ + Retrieves the title/name of a conversation from the conversation_metadata table. + + Args: + conversation_id (str): The unique identifier of the conversation + + Returns: + str: The title of the conversation if found, "Untitled Conversation" if not found + + Raises: + sqlite3.Error: If there's a database error + """ + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT title FROM conversation_metadata WHERE conversation_id = ?", + (conversation_id,) + ) + result = cursor.fetchone() + + if result: + return result[0] + else: + logging.warning(f"No conversation found with ID: {conversation_id}") + return "Untitled Conversation" + + except sqlite3.Error as e: + logging.error(f"Database error retrieving conversation name for ID {conversation_id}: {e}") + raise + except Exception as e: + logging.error(f"Unexpected error retrieving conversation name for ID {conversation_id}: {e}") + raise + + +def search_rag_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the RAG Chat database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of media IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + # Perform the full-text search using the FTS virtual table + cursor.execute(""" + SELECT rag_qa_chats.id, rag_qa_chats.conversation_id, rag_qa_chats.role, rag_qa_chats.content + FROM rag_qa_chats_fts + JOIN rag_qa_chats ON rag_qa_chats_fts.rowid = rag_qa_chats.id + WHERE rag_qa_chats_fts MATCH ? + LIMIT ? 
+ """, (query, fts_top_k)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + # Filter by relevant_media_ids if provided + if relevant_media_ids is not None: + results = [ + r for r in results + if get_conversation_details(r['conversation_id']).get('media_id') in relevant_media_ids + ] + + # Format results + formatted_results = [ + { + "content": r['content'], + "metadata": { + "conversation_id": r['conversation_id'], + "role": r['role'], + "media_id": get_conversation_details(r['conversation_id']).get('media_id') + } + } + for r in results + ] + return formatted_results + + except Exception as e: + logging.error(f"Error in search_rag_chat: {e}") + return [] + + +def search_rag_notes(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the RAG Notes database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of media IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + # Perform the full-text search using the FTS virtual table + cursor.execute(""" + SELECT rag_qa_notes.id, rag_qa_notes.title, rag_qa_notes.content, rag_qa_notes.conversation_id + FROM rag_qa_notes_fts + JOIN rag_qa_notes ON rag_qa_notes_fts.rowid = rag_qa_notes.id + WHERE rag_qa_notes_fts MATCH ? + LIMIT ? + """, (query, fts_top_k)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + # Filter by relevant_media_ids if provided + if relevant_media_ids is not None: + results = [ + r for r in results + if get_conversation_details(r['conversation_id']).get('media_id') in relevant_media_ids + ] + + # Format results + formatted_results = [ + { + "content": r['content'], + "metadata": { + "note_id": r['id'], + "title": r['title'], + "conversation_id": r['conversation_id'], + "media_id": get_conversation_details(r['conversation_id']).get('media_id') + } + } + for r in results + ] + return formatted_results + + except Exception as e: + logging.error(f"Error in search_rag_notes: {e}") + return [] + # # End of Chat-related functions ################################################### +################################################### +# +# Import functions + + +# +# End of Import functions +################################################### + + ################################################### # # Functions to export DB data diff --git a/App_Function_Libraries/DB/SQLite_DB.py b/App_Function_Libraries/DB/SQLite_DB.py index 7ec3bf2e5..6c6c44024 100644 --- a/App_Function_Libraries/DB/SQLite_DB.py +++ b/App_Function_Libraries/DB/SQLite_DB.py @@ -21,7 +21,7 @@ # 11. browse_items(search_query, search_type) # 12. fetch_item_details(media_id: int) # 13. add_media_version(media_id: int, prompt: str, summary: str) -# 14. search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10) +# 14. search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10) # 15. search_and_display(search_query, search_fields, keywords, page) # 16. display_details(index, results) # 17. 
get_details(index, dataframe) @@ -55,12 +55,14 @@ import shutil import sqlite3 import threading +import time import traceback from contextlib import contextmanager from datetime import datetime, timedelta from typing import List, Tuple, Dict, Any, Optional from urllib.parse import quote +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # Local Libraries from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path, \ get_database_dir @@ -342,27 +344,6 @@ def create_tables(db) -> None: ) ''', ''' - CREATE TABLE IF NOT EXISTS ChatConversations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - media_id INTEGER, - media_name TEXT, - conversation_name TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (media_id) REFERENCES Media(id) - ) - ''', - ''' - CREATE TABLE IF NOT EXISTS ChatMessages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - conversation_id INTEGER, - sender TEXT, - message TEXT, - timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES ChatConversations(id) - ) - ''', - ''' CREATE TABLE IF NOT EXISTS Transcripts ( id INTEGER PRIMARY KEY AUTOINCREMENT, media_id INTEGER, @@ -421,8 +402,6 @@ def create_tables(db) -> None: 'CREATE INDEX IF NOT EXISTS idx_mediakeywords_keyword_id ON MediaKeywords(keyword_id)', 'CREATE INDEX IF NOT EXISTS idx_media_version_media_id ON MediaVersion(media_id)', 'CREATE INDEX IF NOT EXISTS idx_mediamodifications_media_id ON MediaModifications(media_id)', - 'CREATE INDEX IF NOT EXISTS idx_chatconversations_media_id ON ChatConversations(media_id)', - 'CREATE INDEX IF NOT EXISTS idx_chatmessages_conversation_id ON ChatMessages(conversation_id)', 'CREATE INDEX IF NOT EXISTS idx_media_is_trash ON Media(is_trash)', 'CREATE INDEX IF NOT EXISTS idx_mediachunks_media_id ON MediaChunks(media_id)', 'CREATE INDEX IF NOT EXISTS idx_unvectorized_media_chunks_media_id ON UnvectorizedMediaChunks(media_id)', @@ -606,7 +585,10 @@ def mark_media_as_processed(database, media_id): # Function to add media with keywords def add_media_with_keywords(url, title, media_type, content, keywords, prompt, summary, transcription_model, author, ingestion_date): + log_counter("add_media_with_keywords_attempt") + start_time = time.time() logging.debug(f"Entering add_media_with_keywords: URL={url}, Title={title}") + # Set default values for missing fields if url is None: url = 'localhost' @@ -622,10 +604,17 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s author = author or 'Unknown' ingestion_date = ingestion_date or datetime.now().strftime('%Y-%m-%d') - if media_type not in ['article', 'audio', 'document', 'mediawiki_article', 'mediawiki_dump', 'obsidian_note', 'podcast', 'text', 'video', 'unknown']: - raise InputError("Invalid media type. Allowed types: article, audio file, document, obsidian_note podcast, text, video, unknown.") + if media_type not in ['article', 'audio', 'book', 'document', 'mediawiki_article', 'mediawiki_dump', + 'obsidian_note', 'podcast', 'text', 'video', 'unknown']: + log_counter("add_media_with_keywords_error", labels={"error_type": "InvalidMediaType"}) + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + raise InputError("Invalid media type. 
Allowed types: article, audio file, document, obsidian_note, podcast, text, video, unknown.") if ingestion_date and not is_valid_date(ingestion_date): + log_counter("add_media_with_keywords_error", labels={"error_type": "InvalidDateFormat"}) + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) raise InputError("Invalid ingestion date format. Use YYYY-MM-DD.") # Handle keywords as either string or list @@ -654,6 +643,7 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s logging.debug(f"Existing media ID for {url}: {existing_media_id}") if existing_media_id: + # Update existing media media_id = existing_media_id logging.debug(f"Updating existing media with ID: {media_id}") cursor.execute(''' @@ -661,7 +651,9 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s SET content = ?, transcription_model = ?, type = ?, author = ?, ingestion_date = ? WHERE id = ? ''', (content, transcription_model, media_type, author, ingestion_date, media_id)) + log_counter("add_media_with_keywords_update") else: + # Insert new media logging.debug("Inserting new media") cursor.execute(''' INSERT INTO Media (url, title, type, content, author, ingestion_date, transcription_model) @@ -669,6 +661,7 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s ''', (url, title, media_type, content, author, ingestion_date, transcription_model)) media_id = cursor.lastrowid logging.debug(f"New media inserted with ID: {media_id}") + log_counter("add_media_with_keywords_insert") cursor.execute(''' INSERT INTO MediaModifications (media_id, prompt, summary, modification_date) @@ -698,13 +691,23 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s conn.commit() logging.info(f"Media '{title}' successfully added/updated with ID: {media_id}") - return media_id, f"Media '{title}' added/updated successfully with keywords: {', '.join(keyword_list)}" + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + log_counter("add_media_with_keywords_success") + + return media_id, f"Media '{title}' added/updated successfully with keywords: {', '.join(keyword_list)}" except sqlite3.Error as e: logging.error(f"SQL Error in add_media_with_keywords: {e}") + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + log_counter("add_media_with_keywords_error", labels={"error_type": "SQLiteError"}) raise DatabaseError(f"Error adding media with keywords: {e}") except Exception as e: logging.error(f"Unexpected Error in add_media_with_keywords: {e}") + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + log_counter("add_media_with_keywords_error", labels={"error_type": type(e).__name__}) raise DatabaseError(f"Unexpected error: {e}") @@ -779,7 +782,13 @@ def ingest_article_to_db(url, title, author, content, keywords, summary, ingesti # Function to add a keyword def add_keyword(keyword: str) -> int: + log_counter("add_keyword_attempt") + start_time = time.time() + if not keyword.strip(): + log_counter("add_keyword_error", labels={"error_type": "EmptyKeyword"}) + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) raise DatabaseError("Keyword cannot be empty") keyword = keyword.strip().lower() @@ -801,18 +810,32 @@ def add_keyword(keyword: str) -> int: logging.info(f"Keyword '{keyword}' added or updated with ID: {keyword_id}") 
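            # The commit just below persists the insert/update reported by the log line above.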
conn.commit() + + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_success") + return keyword_id except sqlite3.IntegrityError as e: logging.error(f"Integrity error adding keyword: {e}") + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_error", labels={"error_type": "IntegrityError"}) raise DatabaseError(f"Integrity error adding keyword: {e}") except sqlite3.Error as e: logging.error(f"Error adding keyword: {e}") + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_error", labels={"error_type": "SQLiteError"}) raise DatabaseError(f"Error adding keyword: {e}") # Function to delete a keyword def delete_keyword(keyword: str) -> str: + log_counter("delete_keyword_attempt") + start_time = time.time() + keyword = keyword.strip().lower() with db.get_connection() as conn: cursor = conn.cursor() @@ -823,10 +846,23 @@ def delete_keyword(keyword: str) -> str: cursor.execute('DELETE FROM Keywords WHERE keyword = ?', (keyword,)) cursor.execute('DELETE FROM keyword_fts WHERE rowid = ?', (keyword_id[0],)) conn.commit() + + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_success") + return f"Keyword '{keyword}' deleted successfully." else: + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_not_found") + return f"Keyword '{keyword}' not found." except sqlite3.Error as e: + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_error", labels={"error_type": type(e).__name__}) + logging.error(f"Error deleting keyword: {e}") raise DatabaseError(f"Error deleting keyword: {e}") @@ -1000,7 +1036,7 @@ def add_media_version(conn, media_id: int, prompt: str, summary: str) -> None: # Function to search the database with advanced options, including keyword search and full-text search -def sqlite_search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10, connection=None): +def search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 20, connection=None): if page < 1: raise ValueError("Page number must be 1 or greater.") @@ -1055,7 +1091,7 @@ def execute_query(conn): # Gradio function to handle user input and display results with pagination, with better feedback def search_and_display(search_query, search_fields, keywords, page): - results = sqlite_search_db(search_query, search_fields, keywords, page) + results = search_media_db(search_query, search_fields, keywords, page) if isinstance(results, pd.DataFrame): # Convert DataFrame to a list of tuples or lists @@ -1133,7 +1169,7 @@ def format_results(results): # Function to export search results to CSV or markdown with pagination def export_to_file(search_query: str, search_fields: List[str], keyword: str, page: int = 1, results_per_file: int = 1000, export_format: str = 'csv'): try: - results = sqlite_search_db(search_query, search_fields, keyword, page, results_per_file) + results = search_media_db(search_query, search_fields, keyword, page, results_per_file) if not results: return "No results found to export." 
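A minimal caller-side sketch for the renamed search helper (the field names and the
keyword string below are illustrative assumptions, not values taken from this diff):

    from App_Function_Libraries.DB.SQLite_DB import search_media_db

    # Second page of matches (20 per page is the new default) for media whose
    # title or content mentions "transformers" and that carry the "ml" keyword.
    results = search_media_db(
        search_query="transformers",
        search_fields=["title", "content"],
        keywords="ml",
        page=2,
    )
    print(results)

search_and_display() above already copes with a DataFrame return by converting it to
row lists, so callers should not rely on a fixed row shape.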
@@ -1381,303 +1417,6 @@ def schedule_chunking(media_id: int, content: str, media_name: str): ####################################################################################################################### -####################################################################################################################### -# -# Functions to manage prompts DB - -def create_prompts_db(): - logging.debug("create_prompts_db: Creating prompts database.") - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.executescript(''' - CREATE TABLE IF NOT EXISTS Prompts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - author TEXT, - details TEXT, - system TEXT, - user TEXT - ); - CREATE TABLE IF NOT EXISTS Keywords ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - keyword TEXT NOT NULL UNIQUE COLLATE NOCASE - ); - CREATE TABLE IF NOT EXISTS PromptKeywords ( - prompt_id INTEGER, - keyword_id INTEGER, - FOREIGN KEY (prompt_id) REFERENCES Prompts (id), - FOREIGN KEY (keyword_id) REFERENCES Keywords (id), - PRIMARY KEY (prompt_id, keyword_id) - ); - CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON Keywords(keyword); - CREATE INDEX IF NOT EXISTS idx_promptkeywords_prompt_id ON PromptKeywords(prompt_id); - CREATE INDEX IF NOT EXISTS idx_promptkeywords_keyword_id ON PromptKeywords(keyword_id); - ''') - -# FIXME - dirty hack that should be removed later... -# Migration function to add the 'author' column to the Prompts table -def add_author_column_to_prompts(): - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - # Check if 'author' column already exists - cursor.execute("PRAGMA table_info(Prompts)") - columns = [col[1] for col in cursor.fetchall()] - - if 'author' not in columns: - # Add the 'author' column - cursor.execute('ALTER TABLE Prompts ADD COLUMN author TEXT') - print("Author column added to Prompts table.") - else: - print("Author column already exists in Prompts table.") - -add_author_column_to_prompts() - -def normalize_keyword(keyword): - return re.sub(r'\s+', ' ', keyword.strip().lower()) - - -# FIXME - update calls to this function to use the new args -def add_prompt(name, author, details, system=None, user=None, keywords=None): - logging.debug(f"add_prompt: Adding prompt with name: {name}, author: {author}, system: {system}, user: {user}, keywords: {keywords}") - if not name: - logging.error("add_prompt: A name is required.") - return "A name is required." - - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO Prompts (name, author, details, system, user) - VALUES (?, ?, ?, ?, ?) - ''', (name, author, details, system, user)) - prompt_id = cursor.lastrowid - - if keywords: - normalized_keywords = [normalize_keyword(k) for k in keywords if k.strip()] - for keyword in set(normalized_keywords): # Use set to remove duplicates - cursor.execute(''' - INSERT OR IGNORE INTO Keywords (keyword) VALUES (?) - ''', (keyword,)) - cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) - keyword_id = cursor.fetchone()[0] - cursor.execute(''' - INSERT OR IGNORE INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?) - ''', (prompt_id, keyword_id)) - return "Prompt added successfully." - except sqlite3.IntegrityError: - return "Prompt with this name already exists." 
- except sqlite3.Error as e: - return f"Database error: {e}" - - -def fetch_prompt_details(name): - logging.debug(f"fetch_prompt_details: Fetching details for prompt: {name}") - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT p.name, p.author, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords - FROM Prompts p - LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id - LEFT JOIN Keywords k ON pk.keyword_id = k.id - WHERE p.name = ? - GROUP BY p.id - ''', (name,)) - return cursor.fetchone() - - -def list_prompts(page=1, per_page=10): - logging.debug(f"list_prompts: Listing prompts for page {page} with {per_page} prompts per page.") - offset = (page - 1) * per_page - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute('SELECT name FROM Prompts LIMIT ? OFFSET ?', (per_page, offset)) - prompts = [row[0] for row in cursor.fetchall()] - - # Get total count of prompts - cursor.execute('SELECT COUNT(*) FROM Prompts') - total_count = cursor.fetchone()[0] - - total_pages = (total_count + per_page - 1) // per_page - return prompts, total_pages, page - -# This will not scale. For a large number of prompts, use a more efficient method. -# FIXME - see above statement. -def load_preset_prompts(): - logging.debug("load_preset_prompts: Loading preset prompts.") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute('SELECT name FROM Prompts ORDER BY name ASC') - prompts = [row[0] for row in cursor.fetchall()] - return prompts - except sqlite3.Error as e: - print(f"Database error: {e}") - return [] - - -def insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords=None): - return add_prompt(title, author, description, system_prompt, user_prompt, keywords) - - -def get_prompt_db_connection(): - prompt_db_path = get_database_path('prompts.db') - return sqlite3.connect(prompt_db_path) - - -def search_prompts(query): - logging.debug(f"search_prompts: Searching prompts with query: {query}") - try: - with get_prompt_db_connection() as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT p.name, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords - FROM Prompts p - LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id - LEFT JOIN Keywords k ON pk.keyword_id = k.id - WHERE p.name LIKE ? OR p.details LIKE ? OR p.system LIKE ? OR p.user LIKE ? OR k.keyword LIKE ? - GROUP BY p.id - ORDER BY p.name - """, (f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%')) - return cursor.fetchall() - except sqlite3.Error as e: - logging.error(f"Error searching prompts: {e}") - return [] - - -def search_prompts_by_keyword(keyword, page=1, per_page=10): - logging.debug(f"search_prompts_by_keyword: Searching prompts by keyword: {keyword}") - normalized_keyword = normalize_keyword(keyword) - offset = (page - 1) * per_page - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT DISTINCT p.name - FROM Prompts p - JOIN PromptKeywords pk ON p.id = pk.prompt_id - JOIN Keywords k ON pk.keyword_id = k.id - WHERE k.keyword LIKE ? - LIMIT ? OFFSET ? 
- ''', ('%' + normalized_keyword + '%', per_page, offset)) - prompts = [row[0] for row in cursor.fetchall()] - - # Get total count of matching prompts - cursor.execute(''' - SELECT COUNT(DISTINCT p.id) - FROM Prompts p - JOIN PromptKeywords pk ON p.id = pk.prompt_id - JOIN Keywords k ON pk.keyword_id = k.id - WHERE k.keyword LIKE ? - ''', ('%' + normalized_keyword + '%',)) - total_count = cursor.fetchone()[0] - - total_pages = (total_count + per_page - 1) // per_page - return prompts, total_pages, page - - -def update_prompt_keywords(prompt_name, new_keywords): - logging.debug(f"update_prompt_keywords: Updating keywords for prompt: {prompt_name}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - - cursor.execute('SELECT id FROM Prompts WHERE name = ?', (prompt_name,)) - prompt_id = cursor.fetchone() - if not prompt_id: - return "Prompt not found." - prompt_id = prompt_id[0] - - cursor.execute('DELETE FROM PromptKeywords WHERE prompt_id = ?', (prompt_id,)) - - normalized_keywords = [normalize_keyword(k) for k in new_keywords if k.strip()] - for keyword in set(normalized_keywords): # Use set to remove duplicates - cursor.execute('INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)', (keyword,)) - cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) - keyword_id = cursor.fetchone()[0] - cursor.execute('INSERT INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)', - (prompt_id, keyword_id)) - - # Remove unused keywords - cursor.execute(''' - DELETE FROM Keywords - WHERE id NOT IN (SELECT DISTINCT keyword_id FROM PromptKeywords) - ''') - return "Keywords updated successfully." - except sqlite3.Error as e: - return f"Database error: {e}" - - -def add_or_update_prompt(title, author, description, system_prompt, user_prompt, keywords=None): - logging.debug(f"add_or_update_prompt: Adding or updating prompt: {title}") - if not title: - return "Error: Title is required." - - existing_prompt = fetch_prompt_details(title) - if existing_prompt: - # Update existing prompt - result = update_prompt_in_db(title, author, description, system_prompt, user_prompt) - if "successfully" in result: - # Update keywords if the prompt update was successful - keyword_result = update_prompt_keywords(title, keywords or []) - result += f" {keyword_result}" - else: - # Insert new prompt - result = insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords) - - return result - - -def load_prompt_details(selected_prompt): - logging.debug(f"load_prompt_details: Loading prompt details for {selected_prompt}") - if selected_prompt: - details = fetch_prompt_details(selected_prompt) - if details: - return details[0], details[1], details[2], details[3], details[4], details[5] - return "", "", "", "", "", "" - - -def update_prompt_in_db(title, author, description, system_prompt, user_prompt): - logging.debug(f"update_prompt_in_db: Updating prompt: {title}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute( - "UPDATE Prompts SET author = ?, details = ?, system = ?, user = ? WHERE name = ?", - (author, description, system_prompt, user_prompt, title) - ) - if cursor.rowcount == 0: - return "No prompt found with the given title." - return "Prompt updated successfully!" 
- except sqlite3.Error as e: - return f"Error updating prompt: {e}" - - -create_prompts_db() - -def delete_prompt(prompt_id): - logging.debug(f"delete_prompt: Deleting prompt with ID: {prompt_id}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - - # Delete associated keywords - cursor.execute("DELETE FROM PromptKeywords WHERE prompt_id = ?", (prompt_id,)) - - # Delete the prompt - cursor.execute("DELETE FROM Prompts WHERE id = ?", (prompt_id,)) - - if cursor.rowcount == 0: - return f"No prompt found with ID {prompt_id}" - else: - conn.commit() - return f"Prompt with ID {prompt_id} has been successfully deleted" - except sqlite3.Error as e: - return f"An error occurred: {e}" - -# -# -####################################################################################################################### - - ####################################################################################################################### # # Function to fetch/update media content @@ -2020,204 +1759,6 @@ def import_obsidian_note_to_db(note_data): ####################################################################################################################### -####################################################################################################################### -# -# Chat-related Functions - - - -def create_chat_conversation(media_id, conversation_name): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO ChatConversations (media_id, conversation_name, created_at, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (media_id, conversation_name)) - conn.commit() - return cursor.lastrowid - except sqlite3.Error as e: - logging.error(f"Error creating chat conversation: {e}") - raise DatabaseError(f"Error creating chat conversation: {e}") - - -def add_chat_message(conversation_id: int, sender: str, message: str) -> int: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message) - VALUES (?, ?, ?) - ''', (conversation_id, sender, message)) - conn.commit() - return cursor.lastrowid - except sqlite3.Error as e: - logging.error(f"Error adding chat message: {e}") - raise DatabaseError(f"Error adding chat message: {e}") - - -def get_chat_messages(conversation_id: int) -> List[Dict[str, Any]]: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT id, sender, message, timestamp - FROM ChatMessages - WHERE conversation_id = ? - ORDER BY timestamp ASC - ''', (conversation_id,)) - messages = cursor.fetchall() - return [ - { - 'id': msg[0], - 'sender': msg[1], - 'message': msg[2], - 'timestamp': msg[3] - } - for msg in messages - ] - except sqlite3.Error as e: - logging.error(f"Error retrieving chat messages: {e}") - raise DatabaseError(f"Error retrieving chat messages: {e}") - - -def search_chat_conversations(search_query: str) -> List[Dict[str, Any]]: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT cc.id, cc.media_id, cc.conversation_name, cc.created_at, m.title as media_title - FROM ChatConversations cc - LEFT JOIN Media m ON cc.media_id = m.id - WHERE cc.conversation_name LIKE ? OR m.title LIKE ? 
- ORDER BY cc.updated_at DESC - ''', (f'%{search_query}%', f'%{search_query}%')) - conversations = cursor.fetchall() - return [ - { - 'id': conv[0], - 'media_id': conv[1], - 'conversation_name': conv[2], - 'created_at': conv[3], - 'media_title': conv[4] or "Unknown Media" - } - for conv in conversations - ] - except sqlite3.Error as e: - logging.error(f"Error searching chat conversations: {e}") - return [] - - -def update_chat_message(message_id: int, new_message: str) -> None: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - UPDATE ChatMessages - SET message = ?, timestamp = CURRENT_TIMESTAMP - WHERE id = ? - ''', (new_message, message_id)) - conn.commit() - except sqlite3.Error as e: - logging.error(f"Error updating chat message: {e}") - raise DatabaseError(f"Error updating chat message: {e}") - - -def delete_chat_message(message_id: int) -> None: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute('DELETE FROM ChatMessages WHERE id = ?', (message_id,)) - conn.commit() - except sqlite3.Error as e: - logging.error(f"Error deleting chat message: {e}") - raise DatabaseError(f"Error deleting chat message: {e}") - - -def save_chat_history_to_database(chatbot, conversation_id, media_id, media_name, conversation_name): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - # If conversation_id is None, create a new conversation - if conversation_id is None: - cursor.execute(''' - INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (media_id, media_name, conversation_name)) - conversation_id = cursor.lastrowid - else: - # If conversation exists, update the media_name - cursor.execute(''' - UPDATE ChatConversations - SET media_name = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (media_name, conversation_id)) - - # Save each message in the chatbot history - for i, (user_msg, ai_msg) in enumerate(chatbot): - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ''', (conversation_id, 'user', user_msg)) - - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ''', (conversation_id, 'ai', ai_msg)) - - # Update the conversation's updated_at timestamp - cursor.execute(''' - UPDATE ChatConversations - SET updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (conversation_id,)) - - conn.commit() - - return conversation_id - except Exception as e: - logging.error(f"Error saving chat history to database: {str(e)}") - raise - - -def get_conversation_name(conversation_id): - if conversation_id is None: - return None - - try: - with sqlite3.connect('media_summary.db') as conn: # Replace with your actual database name - cursor = conn.cursor() - - query = """ - SELECT conversation_name, media_name - FROM ChatConversations - WHERE id = ? 
- """ - - cursor.execute(query, (conversation_id,)) - result = cursor.fetchone() - - if result: - conversation_name, media_name = result - if conversation_name: - return conversation_name - elif media_name: - return f"{media_name}-chat" - - return None # Return None if no result found - except sqlite3.Error as e: - logging.error(f"Database error in get_conversation_name: {e}") - return None - except Exception as e: - logging.error(f"Unexpected error in get_conversation_name: {e}") - return None - -# -# End of Chat-related Functions -####################################################################################################################### - - ####################################################################################################################### # # Functions to Compare Transcripts @@ -2837,29 +2378,42 @@ def process_chunks(database, chunks: List[Dict], media_id: int, batch_size: int :param media_id: ID of the media these chunks belong to :param batch_size: Number of chunks to process in each batch """ + log_counter("process_chunks_attempt", labels={"media_id": media_id}) + start_time = time.time() total_chunks = len(chunks) processed_chunks = 0 - for i in range(0, total_chunks, batch_size): - batch = chunks[i:i + batch_size] - chunk_data = [ - (media_id, chunk['text'], chunk['start_index'], chunk['end_index']) - for chunk in batch - ] - - try: - database.execute_many( - "INSERT INTO MediaChunks (media_id, chunk_text, start_index, end_index) VALUES (?, ?, ?, ?)", - chunk_data - ) - processed_chunks += len(batch) - logging.info(f"Processed {processed_chunks}/{total_chunks} chunks for media_id {media_id}") - except Exception as e: - logging.error(f"Error inserting chunk batch for media_id {media_id}: {e}") - # Optionally, you could raise an exception here to stop processing - # raise + try: + for i in range(0, total_chunks, batch_size): + batch = chunks[i:i + batch_size] + chunk_data = [ + (media_id, chunk['text'], chunk['start_index'], chunk['end_index']) + for chunk in batch + ] - logging.info(f"Finished processing all {total_chunks} chunks for media_id {media_id}") + try: + database.execute_many( + "INSERT INTO MediaChunks (media_id, chunk_text, start_index, end_index) VALUES (?, ?, ?, ?)", + chunk_data + ) + processed_chunks += len(batch) + logging.info(f"Processed {processed_chunks}/{total_chunks} chunks for media_id {media_id}") + log_counter("process_chunks_batch_success", labels={"media_id": media_id}) + except Exception as e: + logging.error(f"Error inserting chunk batch for media_id {media_id}: {e}") + log_counter("process_chunks_batch_error", labels={"media_id": media_id, "error_type": type(e).__name__}) + # Optionally, you could raise an exception here to stop processing + # raise + + logging.info(f"Finished processing all {total_chunks} chunks for media_id {media_id}") + duration = time.time() - start_time + log_histogram("process_chunks_duration", duration, labels={"media_id": media_id}) + log_counter("process_chunks_success", labels={"media_id": media_id}) + except Exception as e: + duration = time.time() - start_time + log_histogram("process_chunks_duration", duration, labels={"media_id": media_id}) + log_counter("process_chunks_error", labels={"media_id": media_id, "error_type": type(e).__name__}) + logging.error(f"Error processing chunks for media_id {media_id}: {e}") # Usage example: @@ -2995,46 +2549,48 @@ def update_media_table(db): # # Workflow Functions +# Workflow Functions def save_workflow_chat_to_db(chat_history, workflow_name, 
conversation_id=None): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - if conversation_id is None: - # Create a new conversation - conversation_name = f"{workflow_name}_Workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - cursor.execute(''' - INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) - VALUES (NULL, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (workflow_name, conversation_name)) - conversation_id = cursor.lastrowid - else: - # Update existing conversation - cursor.execute(''' - UPDATE ChatConversations - SET updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (conversation_id,)) - - # Save messages - for user_msg, ai_msg in chat_history: - if user_msg: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, 'user', ?, CURRENT_TIMESTAMP) - ''', (conversation_id, user_msg)) - if ai_msg: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, 'ai', ?, CURRENT_TIMESTAMP) - ''', (conversation_id, ai_msg)) - - conn.commit() - - return conversation_id, f"Chat saved successfully! Conversation ID: {conversation_id}" - except Exception as e: - logging.error(f"Error saving workflow chat to database: {str(e)}") - return None, f"Error saving chat to database: {str(e)}" + pass +# try: +# with db.get_connection() as conn: +# cursor = conn.cursor() +# +# if conversation_id is None: +# # Create a new conversation +# conversation_name = f"{workflow_name}_Workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}" +# cursor.execute(''' +# INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) +# VALUES (NULL, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +# ''', (workflow_name, conversation_name)) +# conversation_id = cursor.lastrowid +# else: +# # Update existing conversation +# cursor.execute(''' +# UPDATE ChatConversations +# SET updated_at = CURRENT_TIMESTAMP +# WHERE id = ? +# ''', (conversation_id,)) +# +# # Save messages +# for user_msg, ai_msg in chat_history: +# if user_msg: +# cursor.execute(''' +# INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) +# VALUES (?, 'user', ?, CURRENT_TIMESTAMP) +# ''', (conversation_id, user_msg)) +# if ai_msg: +# cursor.execute(''' +# INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) +# VALUES (?, 'ai', ?, CURRENT_TIMESTAMP) +# ''', (conversation_id, ai_msg)) +# +# conn.commit() +# +# return conversation_id, f"Chat saved successfully! 
Conversation ID: {conversation_id}" +# except Exception as e: +# logging.error(f"Error saving workflow chat to database: {str(e)}") +# return None, f"Error saving chat to database: {str(e)}" def get_workflow_chat(conversation_id): diff --git a/App_Function_Libraries/Gradio_Related.py b/App_Function_Libraries/Gradio_Related.py index c2911040a..4f4e217d0 100644 --- a/App_Function_Libraries/Gradio_Related.py +++ b/App_Function_Libraries/Gradio_Related.py @@ -9,39 +9,43 @@ import logging import os import webbrowser - # # Import 3rd-Party Libraries import gradio as gr # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_db_config -from App_Function_Libraries.Gradio_UI.Anki_Validation_tab import create_anki_validation_tab +from App_Function_Libraries.DB.DB_Manager import get_db_config, backup_dir +from App_Function_Libraries.DB.RAG_QA_Chat_DB import create_tables +from App_Function_Libraries.Gradio_UI.Anki_tab import create_anki_validation_tab, create_anki_generator_tab from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab +from App_Function_Libraries.Gradio_UI.Backup_RAG_Notes_Character_Chat_tab import create_database_management_interface from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab from App_Function_Libraries.Gradio_UI.Character_Chat_tab import create_character_card_interaction_tab, create_character_chat_mgmt_tab, create_custom_character_card_tab, \ create_character_card_validation_tab, create_export_characters_tab from App_Function_Libraries.Gradio_UI.Character_interaction_tab import create_narrator_controlled_conversation_tab, \ create_multiple_character_chat_tab -from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_management_tab, \ - create_chat_interface_four, create_chat_interface_multi_api, create_chat_interface_stacked, create_chat_interface +from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_interface_four, create_chat_interface_multi_api, \ + create_chat_interface_stacked, create_chat_interface from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab -from App_Function_Libraries.Gradio_UI.Export_Functionality import create_export_tab -from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \ - create_restore_backup_tab +from App_Function_Libraries.Gradio_UI.Export_Functionality import create_rag_export_tab, create_export_tabs +#from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \ +# create_restore_backup_tab from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \ - create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab + create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab, \ + create_conversation_import_tab from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \ - create_delete_keyword_tab, create_export_keywords_tab + create_delete_keyword_tab, create_export_keywords_tab, create_rag_qa_keywords_tab, create_character_keywords_tab, \ + create_meta_keywords_tab, create_prompt_keywords_tab from 
App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab #from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \ create_media_edit_and_clone_tab, create_media_edit_tab from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab +from App_Function_Libraries.Gradio_UI.Mind_Map_tab import create_mindmap_tab from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, create_pdf_ingestion_test_tab from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab @@ -62,16 +66,19 @@ from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab from App_Function_Libraries.Gradio_UI.View_tab import create_manage_items_tab from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab -from App_Function_Libraries.Gradio_UI.Chat_Workflows import chat_workflows_tab -from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_prompt_view_tab, \ - create_view_all_mediadb_with_versions_tab, create_viewing_mediadb_tab, create_view_all_rag_notes_tab, \ - create_viewing_ragdb_tab, create_mediadb_keyword_search_tab, create_ragdb_keyword_items_tab +from App_Function_Libraries.Gradio_UI.Workflows_tab import chat_workflows_tab +from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_view_all_mediadb_with_versions_tab, \ + create_viewing_mediadb_tab, create_view_all_rag_notes_tab, create_viewing_ragdb_tab, \ + create_mediadb_keyword_search_tab, create_ragdb_keyword_items_tab +from App_Function_Libraries.Gradio_UI.Prompts_tab import create_prompt_view_tab, create_prompts_export_tab # # Gradio UI Imports from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab from App_Function_Libraries.Gradio_UI.XML_Ingestion_Tab import create_xml_import_tab #from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab +from App_Function_Libraries.Utils.Utils import load_and_log_configs + # ####################################################################################################################### # Function Definitions @@ -235,6 +242,147 @@ # all_prompts2 = prompts_category_1 + prompts_category_2 + +####################################################################################################################### +# +# Migration Script +import sqlite3 +import uuid +import logging +import os +from datetime import datetime +import shutil + +# def migrate_media_db_to_rag_chat_db(media_db_path, rag_chat_db_path): +# # Check if migration is needed +# if not os.path.exists(media_db_path): +# logging.info("Media DB does not exist. No migration needed.") +# return +# +# # Optional: Check if migration has already been completed +# migration_flag = os.path.join(os.path.dirname(rag_chat_db_path), 'migration_completed.flag') +# if os.path.exists(migration_flag): +# logging.info("Migration already completed. 
Skipping migration.") +# return +# +# # Backup databases +# backup_database(media_db_path) +# backup_database(rag_chat_db_path) +# +# # Connect to both databases +# try: +# media_conn = sqlite3.connect(media_db_path) +# rag_conn = sqlite3.connect(rag_chat_db_path) +# +# # Enable foreign key support +# media_conn.execute('PRAGMA foreign_keys = ON;') +# rag_conn.execute('PRAGMA foreign_keys = ON;') +# +# media_cursor = media_conn.cursor() +# rag_cursor = rag_conn.cursor() +# +# # Begin transaction +# rag_conn.execute('BEGIN TRANSACTION;') +# +# # Extract conversations from media DB +# media_cursor.execute(''' +# SELECT id, media_id, media_name, conversation_name, created_at, updated_at +# FROM ChatConversations +# ''') +# conversations = media_cursor.fetchall() +# +# for conv in conversations: +# old_conv_id, media_id, media_name, conversation_name, created_at, updated_at = conv +# +# # Convert timestamps if necessary +# created_at = parse_timestamp(created_at) +# updated_at = parse_timestamp(updated_at) +# +# # Generate a new conversation_id +# conversation_id = str(uuid.uuid4()) +# title = conversation_name or (f"{media_name}-chat" if media_name else "Untitled Conversation") +# +# # Insert into conversation_metadata +# rag_cursor.execute(''' +# INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title, media_id) +# VALUES (?, ?, ?, ?, ?) +# ''', (conversation_id, created_at, updated_at, title, media_id)) +# +# # Extract messages from media DB +# media_cursor.execute(''' +# SELECT sender, message, timestamp +# FROM ChatMessages +# WHERE conversation_id = ? +# ORDER BY timestamp ASC +# ''', (old_conv_id,)) +# messages = media_cursor.fetchall() +# +# for msg in messages: +# sender, content, timestamp = msg +# +# # Convert timestamp if necessary +# timestamp = parse_timestamp(timestamp) +# +# role = sender # Assuming 'sender' is 'user' or 'ai' +# +# # Insert message into rag_qa_chats +# rag_cursor.execute(''' +# INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content) +# VALUES (?, ?, ?, ?) +# ''', (conversation_id, timestamp, role, content)) +# +# # Commit transaction +# rag_conn.commit() +# logging.info("Migration completed successfully.") +# +# # Mark migration as complete +# with open(migration_flag, 'w') as f: +# f.write('Migration completed on ' + datetime.now().isoformat()) +# +# except Exception as e: +# # Rollback transaction in case of error +# rag_conn.rollback() +# logging.error(f"Error during migration: {e}") +# raise +# finally: +# media_conn.close() +# rag_conn.close() + +def backup_database(db_path): + backup_path = db_path + '.backup' + if not os.path.exists(backup_path): + shutil.copyfile(db_path, backup_path) + logging.info(f"Database backed up to {backup_path}") + else: + logging.info(f"Backup already exists at {backup_path}") + +def parse_timestamp(timestamp_value): + """ + Parses the timestamp from the old database and converts it to a standard format. + Adjust this function based on the actual format of your timestamps. 
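    Falls back in order: ISO 8601 parsing first, then Unix-epoch floats, and finally
    the current time (with a warning logged) when neither format can be parsed.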
+ """ + try: + # Attempt to parse ISO format + return datetime.fromisoformat(timestamp_value).isoformat() + except ValueError: + # Handle other timestamp formats if necessary + # For example, if timestamps are in Unix epoch format + try: + timestamp_float = float(timestamp_value) + return datetime.fromtimestamp(timestamp_float).isoformat() + except ValueError: + # Default to current time if parsing fails + logging.warning(f"Unable to parse timestamp '{timestamp_value}', using current time.") + return datetime.now().isoformat() + +# +# End of Migration Script +####################################################################################################################### + + +####################################################################################################################### +# +# Launch UI Function def launch_ui(share_public=None, server_mode=False): webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark') share=share_public @@ -257,6 +405,19 @@ def launch_ui(share_public=None, server_mode=False): } """ + config = load_and_log_configs() + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "chatDB.db") + rag_chat_db_path = os.path.join(os.path.dirname(media_db_path), "rag_qa.db") + # Initialize the RAG Chat DB (create tables and update schema) + create_tables() + + # Migrate data from the media DB to the RAG Chat DB + #migrate_media_db_to_rag_chat_db(media_db_path, rag_chat_db_path) + + with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface: gr.HTML( """ @@ -290,10 +451,6 @@ def launch_ui(share_public=None, server_mode=False): create_arxiv_tab() create_semantic_scholar_tab() - with gr.TabItem("Text Search", id="text search", visible=True): - create_search_tab() - create_search_summaries_tab() - with gr.TabItem("RAG Chat/Search", id="RAG Chat Notes group", visible=True): create_rag_tab() create_rag_qa_chat_tab() @@ -305,8 +462,6 @@ def launch_ui(share_public=None, server_mode=False): create_chat_interface_stacked() create_chat_interface_multi_api() create_chat_interface_four() - create_chat_with_llamafile_tab() - create_chat_management_tab() chat_workflows_tab() with gr.TabItem("Character Chat", id="character chat group", visible=True): @@ -318,51 +473,56 @@ def launch_ui(share_public=None, server_mode=False): create_narrator_controlled_conversation_tab() create_export_characters_tab() - with gr.TabItem("View DB Items", id="view db items group", visible=True): + with gr.TabItem("Writing Tools", id="writing_tools group", visible=True): + from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab + create_document_feedback_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab + create_grammar_style_check_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab + create_tone_adjustment_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab + create_creative_writing_tab() + from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab + create_mikupad_tab() + + with gr.TabItem("Search/View DB Items", id="view db items group", visible=True): + create_search_tab() + create_search_summaries_tab() create_view_all_mediadb_with_versions_tab() create_viewing_mediadb_tab() create_mediadb_keyword_search_tab() create_view_all_rag_notes_tab() create_viewing_ragdb_tab() 
create_ragdb_keyword_items_tab() - create_prompt_view_tab() with gr.TabItem("Prompts", id='view prompts group', visible=True): - create_prompt_view_tab() - create_prompt_search_tab() - create_prompt_edit_tab() - create_prompt_clone_tab() - create_prompt_suggestion_tab() - - with gr.TabItem("Manage / Edit Existing Items", id="manage group", visible=True): + with gr.Tabs(): + create_prompt_view_tab() + create_prompt_search_tab() + create_prompt_edit_tab() + create_prompt_clone_tab() + create_prompt_suggestion_tab() + create_prompts_export_tab() + + with gr.TabItem("Manage Media DB Items", id="manage group", visible=True): create_media_edit_tab() create_manage_items_tab() create_media_edit_and_clone_tab() - # FIXME - #create_compare_transcripts_tab() with gr.TabItem("Embeddings Management", id="embeddings group", visible=True): create_embeddings_tab() create_view_embeddings_tab() create_purge_embeddings_tab() - with gr.TabItem("Writing Tools", id="writing_tools group", visible=True): - from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab - create_document_feedback_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab - create_grammar_style_check_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab - create_tone_adjustment_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab - create_creative_writing_tab() - from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab - create_mikupad_tab() - with gr.TabItem("Keywords", id="keywords group", visible=True): create_view_keywords_tab() create_add_keyword_tab() create_delete_keyword_tab() create_export_keywords_tab() + create_character_keywords_tab() + create_rag_qa_keywords_tab() + create_meta_keywords_tab() + create_prompt_keywords_tab() with gr.TabItem("Import", id="import group", visible=True): create_import_item_tab() @@ -371,23 +531,38 @@ def launch_ui(share_public=None, server_mode=False): create_import_multiple_prompts_tab() create_mediawiki_import_tab() create_mediawiki_config_tab() + create_conversation_import_tab() with gr.TabItem("Export", id="export group", visible=True): - create_export_tab() - - with gr.TabItem("Backup Management", id="backup group", visible=True): - create_backup_tab() - create_view_backups_tab() - create_restore_backup_tab() + create_export_tabs() + + + with gr.TabItem("Database Management", id="database_management_group", visible=True): + create_database_management_interface( + media_db_config={ + 'db_path': media_db_path, + 'backup_dir': backup_dir + }, + rag_db_config={ + 'db_path': rag_chat_db_path, + 'backup_dir': backup_dir + }, + char_db_config={ + 'db_path': character_chat_db_path, + 'backup_dir': backup_dir + } + ) with gr.TabItem("Utilities", id="util group", visible=True): - # FIXME - #create_anki_generation_tab() - create_anki_validation_tab() + create_mindmap_tab() create_utilities_yt_video_tab() create_utilities_yt_audio_tab() create_utilities_yt_timestamp_tab() + with gr.TabItem("Anki Deck Creation/Validation", id="anki group", visible=True): + create_anki_generator_tab() + create_anki_validation_tab() + with gr.TabItem("Local LLM", id="local llm group", visible=True): create_chat_with_llamafile_tab() create_ollama_tab() diff --git a/App_Function_Libraries/Gradio_UI/Anki_Validation_tab.py b/App_Function_Libraries/Gradio_UI/Anki_Validation_tab.py deleted file mode 100644 index 7560d97b5..000000000 --- 
a/App_Function_Libraries/Gradio_UI/Anki_Validation_tab.py +++ /dev/null @@ -1,836 +0,0 @@ -# Anki_Validation_tab.py -# Description: Gradio functions for the Anki Validation tab -# -# Imports -from datetime import datetime -import base64 -import json -import logging -import os -from pathlib import Path -import shutil -import sqlite3 -import tempfile -from typing import Dict, Any, Optional, Tuple, List -import zipfile -# -# External Imports -import gradio as gr -#from outlines import models, prompts -# -# Local Imports -from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper -from App_Function_Libraries.Third_Party.Anki import sanitize_html, generate_card_choices, \ - export_cards, load_card_for_editing, validate_flashcards, handle_file_upload, \ - validate_for_ui, update_card_with_validation, update_card_choices, format_validation_result, enhanced_file_upload, \ - handle_validation -from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints -# -############################################################################################################ -# -# Functions: - -# def create_anki_generation_tab(): -# try: -# default_value = None -# if default_api_endpoint: -# if default_api_endpoint in global_api_endpoints: -# default_value = format_api_name(default_api_endpoint) -# else: -# logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") -# except Exception as e: -# logging.error(f"Error setting default API endpoint: {str(e)}") -# default_value = None -# with gr.TabItem("Anki Flashcard Generation", visible=True): -# gr.Markdown("# Anki Flashcard Generation") -# chat_history = gr.State([]) -# generated_cards_state = gr.State({}) -# -# # Add progress tracking -# generation_progress = gr.Progress() -# status_message = gr.Status() -# -# with gr.Row(): -# # Left Column: Generation Controls -# with gr.Column(scale=1): -# gr.Markdown("## Content Input") -# source_text = gr.TextArea( -# label="Source Text or Topic", -# placeholder="Enter the text or topic you want to create flashcards from...", -# lines=5 -# ) -# -# # API Configuration -# api_endpoint = gr.Dropdown( -# choices=["None"] + [format_api_name(api) for api in global_api_endpoints], -# value=default_value, -# label="API for Card Generation" -# ) -# api_key = gr.Textbox(label="API Key (if required)", type="password") -# -# with gr.Accordion("Generation Settings", open=True): -# num_cards = gr.Slider( -# minimum=1, -# maximum=20, -# value=5, -# step=1, -# label="Number of Cards" -# ) -# -# card_types = gr.CheckboxGroup( -# choices=["basic", "cloze", "reverse"], -# value=["basic"], -# label="Card Types to Generate" -# ) -# -# difficulty_level = gr.Radio( -# choices=["beginner", "intermediate", "advanced"], -# value="intermediate", -# label="Difficulty Level" -# ) -# -# subject_area = gr.Dropdown( -# choices=[ -# "general", -# "language_learning", -# "science", -# "mathematics", -# "history", -# "geography", -# "computer_science", -# "custom" -# ], -# value="general", -# label="Subject Area" -# ) -# -# custom_subject = gr.Textbox( -# label="Custom Subject", -# visible=False, -# placeholder="Enter custom subject..." 
-# ) -# -# with gr.Accordion("Advanced Options", open=False): -# temperature = gr.Slider( -# label="Temperature", -# minimum=0.00, -# maximum=1.0, -# step=0.05, -# value=0.7 -# ) -# -# max_retries = gr.Slider( -# label="Max Retries on Error", -# minimum=1, -# maximum=5, -# step=1, -# value=3 -# ) -# -# include_examples = gr.Checkbox( -# label="Include example usage", -# value=True -# ) -# -# include_mnemonics = gr.Checkbox( -# label="Generate mnemonics", -# value=True -# ) -# -# include_hints = gr.Checkbox( -# label="Include hints", -# value=True -# ) -# -# tag_style = gr.Radio( -# choices=["broad", "specific", "hierarchical"], -# value="specific", -# label="Tag Style" -# ) -# -# system_prompt = gr.Textbox( -# label="System Prompt", -# value="You are an expert at creating effective Anki flashcards.", -# lines=2 -# ) -# -# generate_button = gr.Button("Generate Flashcards") -# regenerate_button = gr.Button("Regenerate", visible=False) -# error_log = gr.TextArea( -# label="Error Log", -# visible=False, -# lines=3 -# ) -# -# # Right Column: Chat Interface and Preview -# with gr.Column(scale=1): -# gr.Markdown("## Interactive Card Generation") -# chatbot = gr.Chatbot(height=400, elem_classes="chatbot-container") -# -# with gr.Row(): -# msg = gr.Textbox( -# label="Chat to refine cards", -# placeholder="Ask questions or request modifications..." -# ) -# submit_chat = gr.Button("Submit") -# -# gr.Markdown("## Generated Cards Preview") -# generated_cards = gr.JSON(label="Generated Flashcards") -# -# with gr.Row(): -# edit_generated = gr.Button("Edit in Validator") -# save_generated = gr.Button("Save to File") -# clear_chat = gr.Button("Clear Chat") -# -# generation_status = gr.Markdown("") -# download_file = gr.File(label="Download Cards", visible=False) -# -# # Helper Functions and Classes -# class AnkiCardGenerator: -# def __init__(self): -# self.schema = { -# "type": "object", -# "properties": { -# "cards": { -# "type": "array", -# "items": { -# "type": "object", -# "properties": { -# "id": {"type": "string"}, -# "type": {"type": "string", "enum": ["basic", "cloze", "reverse"]}, -# "front": {"type": "string"}, -# "back": {"type": "string"}, -# "tags": { -# "type": "array", -# "items": {"type": "string"} -# }, -# "note": {"type": "string"} -# }, -# "required": ["id", "type", "front", "back", "tags"] -# } -# } -# }, -# "required": ["cards"] -# } -# -# self.template = prompts.TextTemplate(""" -# Generate {num_cards} Anki flashcards about: {text} -# -# Requirements: -# - Difficulty: {difficulty} -# - Subject: {subject} -# - Card Types: {card_types} -# - Include Examples: {include_examples} -# - Include Mnemonics: {include_mnemonics} -# - Include Hints: {include_hints} -# - Tag Style: {tag_style} -# -# Each card must have: -# 1. Unique ID starting with CARD_ -# 2. Type (one of: basic, cloze, reverse) -# 3. Clear question/prompt on front -# 4. Comprehensive answer on back -# 5. Relevant tags including subject and difficulty -# 6. Optional note with study tips or mnemonics -# -# For cloze deletions, use the format {{c1::text to be hidden}}. 
-# -# Ensure each card: -# - Focuses on a single concept -# - Is clear and unambiguous -# - Uses appropriate formatting -# - Has relevant tags -# - Includes requested additional information -# """) -# -# async def generate_with_progress( -# self, -# text: str, -# config: Dict[str, Any], -# progress: gr.Progress -# ) -> GenerationResult: -# try: -# # Initialize progress -# progress(0, desc="Initializing generation...") -# -# # Configure model -# model = models.Model(config["api_endpoint"]) -# -# # Generate with schema validation -# progress(0.3, desc="Generating cards...") -# response = await model.generate( -# self.template, -# schema=self.schema, -# text=text, -# **config -# ) -# -# # Validate response -# progress(0.6, desc="Validating generated cards...") -# validated_cards = self.validate_cards(response) -# -# # Final processing -# progress(0.9, desc="Finalizing...") -# time.sleep(0.5) # Brief pause for UI feedback -# return GenerationResult( -# cards=validated_cards, -# error=None, -# status="Generation completed successfully!", -# progress=1.0 -# ) -# -# except Exception as e: -# logging.error(f"Card generation error: {str(e)}") -# return GenerationResult( -# cards=None, -# error=str(e), -# status=f"Error: {str(e)}", -# progress=1.0 -# ) -# -# def validate_cards(self, cards: Dict[str, Any]) -> Dict[str, Any]: -# """Validate and clean generated cards""" -# if not isinstance(cards, dict) or "cards" not in cards: -# raise ValueError("Invalid card format") -# -# seen_ids = set() -# cleaned_cards = [] -# -# for card in cards["cards"]: -# # Check ID uniqueness -# if card["id"] in seen_ids: -# card["id"] = f"{card['id']}_{len(seen_ids)}" -# seen_ids.add(card["id"]) -# -# # Validate card type -# if card["type"] not in ["basic", "cloze", "reverse"]: -# raise ValueError(f"Invalid card type: {card['type']}") -# -# # Check content -# if not card["front"].strip() or not card["back"].strip(): -# raise ValueError("Empty card content") -# -# # Validate cloze format -# if card["type"] == "cloze" and "{{c1::" not in card["front"]: -# raise ValueError("Invalid cloze format") -# -# # Clean and standardize tags -# if not isinstance(card["tags"], list): -# card["tags"] = [str(card["tags"])] -# card["tags"] = [tag.strip().lower() for tag in card["tags"] if tag.strip()] -# -# cleaned_cards.append(card) -# -# return {"cards": cleaned_cards} -# -# # Initialize generator -# generator = AnkiCardGenerator() -# -# async def generate_flashcards(*args): -# text, num_cards, card_types, difficulty, subject, custom_subject, \ -# include_examples, include_mnemonics, include_hints, tag_style, \ -# temperature, api_endpoint, api_key, system_prompt, max_retries = args -# -# actual_subject = custom_subject if subject == "custom" else subject -# -# config = { -# "num_cards": num_cards, -# "difficulty": difficulty, -# "subject": actual_subject, -# "card_types": card_types, -# "include_examples": include_examples, -# "include_mnemonics": include_mnemonics, -# "include_hints": include_hints, -# "tag_style": tag_style, -# "temperature": temperature, -# "api_endpoint": api_endpoint, -# "api_key": api_key, -# "system_prompt": system_prompt -# } -# -# errors = [] -# retry_count = 0 -# -# while retry_count < max_retries: -# try: -# result = await generator.generate_with_progress(text, config, generation_progress) -# -# if result.error: -# errors.append(f"Attempt {retry_count + 1}: {result.error}") -# retry_count += 1 -# await asyncio.sleep(1) -# continue -# -# return ( -# result.cards, -# gr.update(visible=True), -# result.status, 
-# gr.update(visible=False), -# [[None, "Cards generated! You can now modify them through chat."]] -# ) -# -# except Exception as e: -# errors.append(f"Attempt {retry_count + 1}: {str(e)}") -# retry_count += 1 -# await asyncio.sleep(1) -# -# error_log = "\n".join(errors) -# return ( -# None, -# gr.update(visible=False), -# "Failed to generate cards after all retries", -# gr.update(value=error_log, visible=True), -# [[None, "Failed to generate cards. Please check the error log."]] -# ) -# -# def save_generated_cards(cards): -# if not cards: -# return "No cards to save", None -# -# try: -# cards_json = json.dumps(cards, indent=2) -# current_time = datetime.now().strftime("%Y%m%d_%H%M%S") -# filename = f"anki_cards_{current_time}.json" -# -# return ( -# "Cards saved successfully!", -# (filename, cards_json, "application/json") -# ) -# except Exception as e: -# logging.error(f"Error saving cards: {e}") -# return f"Error saving cards: {str(e)}", None -# -# def clear_chat_history(): -# return [], [], "Chat cleared" -# -# def toggle_custom_subject(choice): -# return gr.update(visible=choice == "custom") -# -# def send_to_validator(cards): -# if not cards: -# return "No cards to validate" -# try: -# # Here you would integrate with your validation tab -# validated_cards = generator.validate_cards(cards) -# return "Cards validated and sent to validator" -# except Exception as e: -# logging.error(f"Validation error: {e}") -# return f"Validation error: {str(e)}" -# -# # Register callbacks -# subject_area.change( -# fn=toggle_custom_subject, -# inputs=subject_area, -# outputs=custom_subject -# ) -# -# generate_button.click( -# fn=generate_flashcards, -# inputs=[ -# source_text, num_cards, card_types, difficulty_level, -# subject_area, custom_subject, include_examples, -# include_mnemonics, include_hints, tag_style, -# temperature, api_endpoint, api_key, system_prompt, -# max_retries -# ], -# outputs=[ -# generated_cards, -# regenerate_button, -# generation_status, -# error_log, -# chatbot -# ] -# ) -# -# regenerate_button.click( -# fn=generate_flashcards, -# inputs=[ -# source_text, num_cards, card_types, difficulty_level, -# subject_area, custom_subject, include_examples, -# include_mnemonics, include_hints, tag_style, -# temperature, api_endpoint, api_key, system_prompt, -# max_retries -# ], -# outputs=[ -# generated_cards, -# regenerate_button, -# generation_status, -# error_log, -# chatbot -# ] -# ) -# -# clear_chat.click( -# fn=clear_chat_history, -# outputs=[chatbot, chat_history, generation_status] -# ) -# -# edit_generated.click( -# fn=send_to_validator, -# inputs=generated_cards, -# outputs=generation_status -# ) -# -# save_generated.click( -# fn=save_generated_cards, -# inputs=generated_cards, -# outputs=[generation_status, download_file] -# ) -# -# return ( -# source_text, num_cards, card_types, difficulty_level, -# subject_area, custom_subject, include_examples, -# include_mnemonics, include_hints, tag_style, -# api_endpoint, api_key, temperature, system_prompt, -# generate_button, regenerate_button, generated_cards, -# edit_generated, save_generated, clear_chat, -# generation_status, chatbot, msg, submit_chat, -# chat_history, generated_cards_state, download_file, -# error_log, max_retries -# ) - -def create_anki_validation_tab(): - with gr.TabItem("Anki Flashcard Validation", visible=True): - gr.Markdown("# Anki Flashcard Validation and Editor") - - # State variables for internal tracking - current_card_data = gr.State({}) - preview_update_flag = gr.State(False) - - with gr.Row(): - # 
Left Column: Input and Validation - with gr.Column(scale=1): - gr.Markdown("## Import or Create Flashcards") - - input_type = gr.Radio( - choices=["JSON", "APKG"], - label="Input Type", - value="JSON" - ) - - with gr.Group() as json_input_group: - flashcard_input = gr.TextArea( - label="Enter Flashcards (JSON format)", - placeholder='''{ - "cards": [ - { - "id": "CARD_001", - "type": "basic", - "front": "What is the capital of France?", - "back": "Paris", - "tags": ["geography", "europe"], - "note": "Remember: City of Light" - } - ] -}''', - lines=10 - ) - - import_json = gr.File( - label="Or Import JSON File", - file_types=[".json"] - ) - - with gr.Group(visible=False) as apkg_input_group: - import_apkg = gr.File( - label="Import APKG File", - file_types=[".apkg"] - ) - deck_info = gr.JSON( - label="Deck Information", - visible=False - ) - - validate_button = gr.Button("Validate Flashcards") - - # Right Column: Validation Results and Editor - with gr.Column(scale=1): - gr.Markdown("## Validation Results") - validation_status = gr.Markdown("") - - with gr.Accordion("Validation Rules", open=False): - gr.Markdown(""" - ### Required Fields: - - Unique ID - - Card Type (basic, cloze, reverse) - - Front content - - Back content - - At least one tag - - ### Content Rules: - - No empty fields - - Front side should be a clear question/prompt - - Back side should contain complete answer - - Cloze deletions must have valid syntax - - No duplicate IDs - - ### Image Rules: - - Valid image tags - - Supported formats (JPG, PNG, GIF) - - Base64 encoded or valid URL - - ### APKG-specific Rules: - - Valid SQLite database structure - - Media files properly referenced - - Note types match Anki standards - - Card templates are well-formed - """) - - with gr.Row(): - # Card Editor - gr.Markdown("## Card Editor") - with gr.Row(): - with gr.Column(scale=1): - with gr.Accordion("Edit Individual Cards", open=True): - card_selector = gr.Dropdown( - label="Select Card to Edit", - choices=[], - interactive=True - ) - - card_type = gr.Radio( - choices=["basic", "cloze", "reverse"], - label="Card Type", - value="basic" - ) - - # Front content with preview - with gr.Group(): - gr.Markdown("### Front Content") - front_content = gr.TextArea( - label="Content (HTML supported)", - lines=3 - ) - front_preview = gr.HTML( - label="Preview" - ) - - # Back content with preview - with gr.Group(): - gr.Markdown("### Back Content") - back_content = gr.TextArea( - label="Content (HTML supported)", - lines=3 - ) - back_preview = gr.HTML( - label="Preview" - ) - - tags_input = gr.TextArea( - label="Tags (comma-separated)", - lines=1 - ) - - notes_input = gr.TextArea( - label="Additional Notes", - lines=2 - ) - - with gr.Row(): - update_card_button = gr.Button("Update Card") - delete_card_button = gr.Button("Delete Card", variant="stop") - - with gr.Row(): - with gr.Column(scale=1): - # Export Options - gr.Markdown("## Export Options") - export_format = gr.Radio( - choices=["Anki CSV", "JSON", "Plain Text"], - label="Export Format", - value="Anki CSV" - ) - export_button = gr.Button("Export Valid Cards") - export_file = gr.File(label="Download Validated Cards") - export_status = gr.Markdown("") - with gr.Column(scale=1): - gr.Markdown("## Export Instructions") - gr.Markdown(""" - ### Anki CSV Format: - - Front, Back, Tags, Type, Note - - Use for importing into Anki - - Images preserved as HTML - - ### JSON Format: - - JSON array of cards - - Images as base64 or URLs - - Use for custom processing - - ### Plain Text Format: - - Question 
and Answer pairs - - Images represented as [IMG] placeholder - - Use for manual review - """) - - def update_preview(content): - """Update preview with sanitized content.""" - if not content: - return "" - return sanitize_html(content) - - # Event handlers - def validation_chain(content: str) -> Tuple[str, List[str]]: - """Combined validation and card choice update.""" - validation_message = validate_for_ui(content) - card_choices = update_card_choices(content) - return validation_message, card_choices - - def delete_card(card_selection, current_content): - """Delete selected card and return updated content.""" - if not card_selection or not current_content: - return current_content, "No card selected", [] - - try: - data = json.loads(current_content) - selected_id = card_selection.split(" - ")[0] - - data['cards'] = [card for card in data['cards'] if card['id'] != selected_id] - new_content = json.dumps(data, indent=2) - - return ( - new_content, - "Card deleted successfully!", - generate_card_choices(new_content) - ) - - except Exception as e: - return current_content, f"Error deleting card: {str(e)}", [] - - def process_validation_result(is_valid, message): - """Process validation result into a formatted markdown string.""" - if is_valid: - return f"✅ {message}" - else: - return f"❌ {message}" - - # Register event handlers - input_type.change( - fn=lambda t: ( - gr.update(visible=t == "JSON"), - gr.update(visible=t == "APKG"), - gr.update(visible=t == "APKG") - ), - inputs=[input_type], - outputs=[json_input_group, apkg_input_group, deck_info] - ) - - # File upload handlers - import_json.upload( - fn=handle_file_upload, - inputs=[import_json, input_type], - outputs=[ - flashcard_input, - deck_info, - validation_status, - card_selector - ] - ) - - import_apkg.upload( - fn=enhanced_file_upload, - inputs=[import_apkg, input_type], - outputs=[ - flashcard_input, - deck_info, - validation_status, - card_selector - ] - ) - - # Validation handler - validate_button.click( - fn=lambda content, input_format: ( - handle_validation(content, input_format), - generate_card_choices(content) if content else [] - ), - inputs=[flashcard_input, input_type], - outputs=[validation_status, card_selector] - ) - - # Card editing handlers - # Card selector change event - card_selector.change( - fn=load_card_for_editing, - inputs=[card_selector, flashcard_input], - outputs=[ - card_type, - front_content, - back_content, - tags_input, - notes_input, - front_preview, - back_preview - ] - ) - - # Live preview updates - front_content.change( - fn=update_preview, - inputs=[front_content], - outputs=[front_preview] - ) - - back_content.change( - fn=update_preview, - inputs=[back_content], - outputs=[back_preview] - ) - - # Card update handler - update_card_button.click( - fn=update_card_with_validation, - inputs=[ - flashcard_input, - card_selector, - card_type, - front_content, - back_content, - tags_input, - notes_input - ], - outputs=[ - flashcard_input, - validation_status, - card_selector - ] - ) - - # Delete card handler - delete_card_button.click( - fn=delete_card, - inputs=[card_selector, flashcard_input], - outputs=[flashcard_input, validation_status, card_selector] - ) - - # Export handler - export_button.click( - fn=export_cards, - inputs=[flashcard_input, export_format], - outputs=[export_status, export_file] - ) - - return ( - flashcard_input, - import_json, - import_apkg, - validate_button, - validation_status, - card_selector, - card_type, - front_content, - back_content, - front_preview, - 
back_preview, - tags_input, - notes_input, - update_card_button, - delete_card_button, - export_format, - export_button, - export_file, - export_status, - deck_info - ) - -# -# End of Anki_Validation_tab.py -############################################################################################################ diff --git a/App_Function_Libraries/Gradio_UI/Anki_tab.py b/App_Function_Libraries/Gradio_UI/Anki_tab.py new file mode 100644 index 000000000..4f261f9b5 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Anki_tab.py @@ -0,0 +1,921 @@ +# Anki_Validation_tab.py +# Description: Gradio functions for the Anki Validation tab +# +# Imports +import json +import logging +import os +import tempfile +from typing import Optional, Tuple, List, Dict +# +# External Imports +import genanki +import gradio as gr +# +# Local Imports +from App_Function_Libraries.Chat.Chat_Functions import approximate_token_count, update_chat_content, save_chat_history, \ + save_chat_history_to_db_wrapper +from App_Function_Libraries.DB.DB_Manager import list_prompts +from App_Function_Libraries.Gradio_UI.Chat_ui import update_dropdown_multiple, chat_wrapper, update_selected_parts, \ + search_conversations, regenerate_last_message, load_conversation, debug_output +from App_Function_Libraries.Third_Party.Anki import sanitize_html, generate_card_choices, \ + export_cards, load_card_for_editing, handle_file_upload, \ + validate_for_ui, update_card_with_validation, update_card_choices, enhanced_file_upload, \ + handle_validation +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +# +############################################################################################################ +# +# Functions: + +def create_anki_validation_tab(): + with gr.TabItem("Anki Flashcard Validation", visible=True): + gr.Markdown("# Anki Flashcard Validation and Editor") + + # State variables for internal tracking + current_card_data = gr.State({}) + preview_update_flag = gr.State(False) + + with gr.Row(): + # Left Column: Input and Validation + with gr.Column(scale=1): + gr.Markdown("## Import or Create Flashcards") + + input_type = gr.Radio( + choices=["JSON", "APKG"], + label="Input Type", + value="JSON" + ) + + with gr.Group() as json_input_group: + flashcard_input = gr.TextArea( + label="Enter Flashcards (JSON format)", + placeholder='''{ + "cards": [ + { + "id": "CARD_001", + "type": "basic", + "front": "What is the capital of France?", + "back": "Paris", + "tags": ["geography", "europe"], + "note": "Remember: City of Light" + } + ] +}''', + lines=10 + ) + + import_json = gr.File( + label="Or Import JSON File", + file_types=[".json"] + ) + + with gr.Group(visible=False) as apkg_input_group: + import_apkg = gr.File( + label="Import APKG File", + file_types=[".apkg"] + ) + deck_info = gr.JSON( + label="Deck Information", + visible=False + ) + + validate_button = gr.Button("Validate Flashcards") + + # Right Column: Validation Results and Editor + with gr.Column(scale=1): + gr.Markdown("## Validation Results") + validation_status = gr.Markdown("") + + with gr.Accordion("Validation Rules", open=False): + gr.Markdown(""" + ### Required Fields: + - Unique ID + - Card Type (basic, cloze, reverse) + - Front content + - Back content + - At least one tag + + ### Content Rules: + - No empty fields + - Front side should be a clear question/prompt + - Back side should contain complete answer + - Cloze deletions must have valid syntax + - No duplicate IDs + + ### Image Rules: + - 
Valid image tags + - Supported formats (JPG, PNG, GIF) + - Base64 encoded or valid URL + + ### APKG-specific Rules: + - Valid SQLite database structure + - Media files properly referenced + - Note types match Anki standards + - Card templates are well-formed + """) + + with gr.Row(): + # Card Editor + gr.Markdown("## Card Editor") + with gr.Row(): + with gr.Column(scale=1): + with gr.Accordion("Edit Individual Cards", open=True): + card_selector = gr.Dropdown( + label="Select Card to Edit", + choices=[], + interactive=True + ) + + card_type = gr.Radio( + choices=["basic", "cloze", "reverse"], + label="Card Type", + value="basic" + ) + + # Front content with preview + with gr.Group(): + gr.Markdown("### Front Content") + front_content = gr.TextArea( + label="Content (HTML supported)", + lines=3 + ) + front_preview = gr.HTML( + label="Preview" + ) + + # Back content with preview + with gr.Group(): + gr.Markdown("### Back Content") + back_content = gr.TextArea( + label="Content (HTML supported)", + lines=3 + ) + back_preview = gr.HTML( + label="Preview" + ) + + tags_input = gr.TextArea( + label="Tags (comma-separated)", + lines=1 + ) + + notes_input = gr.TextArea( + label="Additional Notes", + lines=2 + ) + + with gr.Row(): + update_card_button = gr.Button("Update Card") + delete_card_button = gr.Button("Delete Card", variant="stop") + + with gr.Row(): + with gr.Column(scale=1): + # Export Options + gr.Markdown("## Export Options") + export_format = gr.Radio( + choices=["Anki CSV", "JSON", "Plain Text"], + label="Export Format", + value="Anki CSV" + ) + export_button = gr.Button("Export Valid Cards") + export_file = gr.File(label="Download Validated Cards") + export_status = gr.Markdown("") + with gr.Column(scale=1): + gr.Markdown("## Export Instructions") + gr.Markdown(""" + ### Anki CSV Format: + - Front, Back, Tags, Type, Note + - Use for importing into Anki + - Images preserved as HTML + + ### JSON Format: + - JSON array of cards + - Images as base64 or URLs + - Use for custom processing + + ### Plain Text Format: + - Question and Answer pairs + - Images represented as [IMG] placeholder + - Use for manual review + """) + + def update_preview(content): + """Update preview with sanitized content.""" + if not content: + return "" + return sanitize_html(content) + + # Event handlers + def validation_chain(content: str) -> Tuple[str, List[str]]: + """Combined validation and card choice update.""" + validation_message = validate_for_ui(content) + card_choices = update_card_choices(content) + return validation_message, card_choices + + def delete_card(card_selection, current_content): + """Delete selected card and return updated content.""" + if not card_selection or not current_content: + return current_content, "No card selected", [] + + try: + data = json.loads(current_content) + selected_id = card_selection.split(" - ")[0] + + data['cards'] = [card for card in data['cards'] if card['id'] != selected_id] + new_content = json.dumps(data, indent=2) + + return ( + new_content, + "Card deleted successfully!", + generate_card_choices(new_content) + ) + + except Exception as e: + return current_content, f"Error deleting card: {str(e)}", [] + + def process_validation_result(is_valid, message): + """Process validation result into a formatted markdown string.""" + if is_valid: + return f"✅ {message}" + else: + return f"❌ {message}" + + # Register event handlers + input_type.change( + fn=lambda t: ( + gr.update(visible=t == "JSON"), + gr.update(visible=t == "APKG"), + gr.update(visible=t == "APKG") + 
), + inputs=[input_type], + outputs=[json_input_group, apkg_input_group, deck_info] + ) + + # File upload handlers + import_json.upload( + fn=handle_file_upload, + inputs=[import_json, input_type], + outputs=[ + flashcard_input, + deck_info, + validation_status, + card_selector + ] + ) + + import_apkg.upload( + fn=enhanced_file_upload, + inputs=[import_apkg, input_type], + outputs=[ + flashcard_input, + deck_info, + validation_status, + card_selector + ] + ) + + # Validation handler + validate_button.click( + fn=lambda content, input_format: ( + handle_validation(content, input_format), + generate_card_choices(content) if content else [] + ), + inputs=[flashcard_input, input_type], + outputs=[validation_status, card_selector] + ) + + # Card editing handlers + # Card selector change event + card_selector.change( + fn=load_card_for_editing, + inputs=[card_selector, flashcard_input], + outputs=[ + card_type, + front_content, + back_content, + tags_input, + notes_input, + front_preview, + back_preview + ] + ) + + # Live preview updates + front_content.change( + fn=update_preview, + inputs=[front_content], + outputs=[front_preview] + ) + + back_content.change( + fn=update_preview, + inputs=[back_content], + outputs=[back_preview] + ) + + # Card update handler + update_card_button.click( + fn=update_card_with_validation, + inputs=[ + flashcard_input, + card_selector, + card_type, + front_content, + back_content, + tags_input, + notes_input + ], + outputs=[ + flashcard_input, + validation_status, + card_selector + ] + ) + + # Delete card handler + delete_card_button.click( + fn=delete_card, + inputs=[card_selector, flashcard_input], + outputs=[flashcard_input, validation_status, card_selector] + ) + + # Export handler + export_button.click( + fn=export_cards, + inputs=[flashcard_input, export_format], + outputs=[export_status, export_file] + ) + + return ( + flashcard_input, + import_json, + import_apkg, + validate_button, + validation_status, + card_selector, + card_type, + front_content, + back_content, + front_preview, + back_preview, + tags_input, + notes_input, + update_card_button, + delete_card_button, + export_format, + export_button, + export_file, + export_status, + deck_info + ) + + +def create_anki_generator_tab(): + with gr.TabItem("Anki Deck Generator", visible=True): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + custom_css = """ + .chatbot-container .message-wrap .message { + font-size: 14px !important; + } + """ + with gr.TabItem("LLM Chat & Anki Deck Creation", visible=True): + gr.Markdown("# Chat with an LLM to help you come up with Questions/Answers for an Anki Deck") + chat_history = gr.State([]) + media_content = gr.State({}) + selected_parts = gr.State([]) + conversation_id = gr.State(None) + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + + with gr.Row(): + with gr.Column(scale=1): + search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." 
+ ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." + ) + search_button = gr.Button("Search") + items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) + item_mapping = gr.State({}) + with gr.Row(): + use_content = gr.Checkbox(label="Use Content") + use_summary = gr.Checkbox(label="Use Summary") + use_prompt = gr.Checkbox(label="Use Prompt") + save_conversation = gr.Checkbox(label="Save Conversation", value=False, visible=True) + with gr.Row(): + temperature = gr.Slider(label="Temperature", minimum=0.00, maximum=1.0, step=0.05, value=0.7) + with gr.Row(): + conversation_search = gr.Textbox(label="Search Conversations") + with gr.Row(): + search_conversations_btn = gr.Button("Search Conversations") + with gr.Row(): + previous_conversations = gr.Dropdown(label="Select Conversation", choices=[], interactive=True) + with gr.Row(): + load_conversations_btn = gr.Button("Load Selected Conversation") + + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" + ) + api_key = gr.Textbox(label="API Key (if required)", type="password") + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", + value=False, + visible=True) + preset_prompt_checkbox = gr.Checkbox(label="Use a Pre-set Prompt", + value=False, + visible=True) + with gr.Row(visible=False) as preset_prompt_controls: + prev_prompt_page = gr.Button("Previous") + next_prompt_page = gr.Button("Next") + current_prompt_page_text = gr.Text(f"Page {current_page} of {total_pages}") + current_prompt_page_state = gr.State(value=1) + + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=initial_prompts + ) + user_prompt = gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) + system_prompt_input = gr.Textbox(label="System Prompt", + value="You are a helpful AI assitant", + lines=3, + visible=False) + with gr.Column(scale=2): + chatbot = gr.Chatbot(height=800, elem_classes="chatbot-container") + msg = gr.Textbox(label="Enter your message") + submit = gr.Button("Submit") + regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) + clear_chat_button = gr.Button("Clear Chat") + + chat_media_name = gr.Textbox(label="Custom Chat Name(optional)") + save_chat_history_to_db = gr.Button("Save Chat History to DataBase") + save_status = gr.Textbox(label="Save Status", interactive=False) + save_chat_history_as_file = gr.Button("Save Chat History as File") + download_file = gr.File(label="Download Chat History") + + search_button.click( + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], + outputs=[items_output, item_mapping] + ) + + def update_prompt_visibility(custom_prompt_checked, preset_prompt_checked): + user_prompt_visible = custom_prompt_checked + system_prompt_visible = custom_prompt_checked + preset_prompt_visible = preset_prompt_checked + preset_prompt_controls_visible = preset_prompt_checked + return ( + gr.update(visible=user_prompt_visible, interactive=user_prompt_visible), + gr.update(visible=system_prompt_visible, interactive=system_prompt_visible), + 
gr.update(visible=preset_prompt_visible, interactive=preset_prompt_visible), + gr.update(visible=preset_prompt_controls_visible) + ) + + def update_prompt_page(direction, current_page_val): + new_page = current_page_val + direction + if new_page < 1: + new_page = 1 + prompts, total_pages, _ = list_prompts(page=new_page, per_page=20) + if new_page > total_pages: + new_page = total_pages + prompts, total_pages, _ = list_prompts(page=new_page, per_page=20) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + + def clear_chat(): + return [], None # Return empty list for chatbot and None for conversation_id + + custom_prompt_checkbox.change( + update_prompt_visibility, + inputs=[custom_prompt_checkbox, preset_prompt_checkbox], + outputs=[user_prompt, system_prompt_input, preset_prompt, preset_prompt_controls] + ) + + preset_prompt_checkbox.change( + update_prompt_visibility, + inputs=[custom_prompt_checkbox, preset_prompt_checkbox], + outputs=[user_prompt, system_prompt_input, preset_prompt, preset_prompt_controls] + ) + + prev_prompt_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + next_prompt_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + submit.click( + chat_wrapper, + inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, + conversation_id, + save_conversation, temperature, system_prompt_input], + outputs=[msg, chatbot, conversation_id] + ).then( # Clear the message box after submission + lambda x: gr.update(value=""), + inputs=[chatbot], + outputs=[msg] + ).then( # Clear the user prompt after the first message + lambda: (gr.update(value=""), gr.update(value="")), + outputs=[user_prompt, system_prompt_input] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] + ) + + + clear_chat_button.click( + clear_chat, + outputs=[chatbot, conversation_id] + ) + + items_output.change( + update_chat_content, + inputs=[items_output, use_content, use_summary, use_prompt, item_mapping], + outputs=[media_content, selected_parts] + ) + + use_content.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + use_summary.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + use_prompt.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + items_output.change(debug_output, inputs=[media_content, selected_parts], outputs=[]) + + search_conversations_btn.click( + search_conversations, + inputs=[conversation_search], + outputs=[previous_conversations] + ) + + load_conversations_btn.click( + clear_chat, + outputs=[chatbot, chat_history] + ).then( + load_conversation, + inputs=[previous_conversations], + outputs=[chatbot, conversation_id] + ) + + previous_conversations.change( + load_conversation, + inputs=[previous_conversations], + outputs=[chat_history] + ) + + save_chat_history_as_file.click( + save_chat_history, + inputs=[chatbot, conversation_id], + outputs=[download_file] + ) + + save_chat_history_to_db.click( + save_chat_history_to_db_wrapper, + inputs=[chatbot, conversation_id, media_content, chat_media_name], + outputs=[conversation_id, 
gr.Textbox(label="Save Status")] + ) + + regenerate_button.click( + regenerate_last_message, + inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temperature, + system_prompt_input], + outputs=[chatbot, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] + ) + gr.Markdown("# Create Anki Deck") + + with gr.Row(): + # Left Column: Deck Settings + with gr.Column(scale=1): + gr.Markdown("## Deck Settings") + deck_name = gr.Textbox( + label="Deck Name", + placeholder="My Study Deck", + value="My Study Deck" + ) + + deck_description = gr.Textbox( + label="Deck Description", + placeholder="Description of your deck", + lines=2 + ) + + note_type = gr.Radio( + choices=["Basic", "Basic (and reversed)", "Cloze"], + label="Note Type", + value="Basic" + ) + + # Card Fields based on note type + with gr.Group() as basic_fields: + front_template = gr.Textbox( + label="Front Template (HTML)", + value="{{Front}}", + lines=3 + ) + back_template = gr.Textbox( + label="Back Template (HTML)", + value="{{FrontSide}}
No conversation selected" - except Exception as e: - logging.error(f"Error in load_conversations: {str(e)}") - return f"Error: {str(e)}", "Error loading conversation"
- - def validate_conversation_json(content): - try: - data = json.loads(content) - if not isinstance(data, dict): - return False, "Invalid JSON structure: root should be an object" - if "conversation_id" not in data or not isinstance(data["conversation_id"], int): - return False, "Missing or invalid conversation_id" - if "messages" not in data or not isinstance(data["messages"], list): - return False, "Missing or invalid messages array" - for msg in data["messages"]: - if not all(key in msg for key in ["sender", "message"]): - return False, "Invalid message structure: missing required fields" - return True, data - except json.JSONDecodeError as e: - return False, f"Invalid JSON: {str(e)}" - - def save_conversation(selected, conversation_mapping, content): - if not selected or selected not in conversation_mapping: - return "Please select a conversation before saving.", "No changes made"
- - conversation_id = conversation_mapping[selected] - is_valid, result = validate_conversation_json(content) - - if not is_valid: - return f"Error: {result}", "No changes made due to error"
- - conversation_data = result - if conversation_data["conversation_id"] != conversation_id: - return "Error: Conversation ID mismatch.", "No changes made due to ID mismatch"
- - try: - with db.get_connection() as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() - - # Backup original conversation - cursor.execute("SELECT * FROM ChatMessages WHERE conversation_id = ?", (conversation_id,)) - original_messages = cursor.fetchall() - backup_data = json.dumps({"conversation_id": conversation_id, "messages": original_messages}) - - # You might want to save this backup_data somewhere - - # Delete existing messages - cursor.execute("DELETE FROM ChatMessages WHERE conversation_id = ?", (conversation_id,)) - - # Insert updated messages - for message in conversation_data["messages"]: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, COALESCE(?, CURRENT_TIMESTAMP)) - ''', (conversation_id, message["sender"], message["message"], message.get("timestamp"))) - - conn.commit() - - # Create updated HTML preview - html_preview = "Error occurred while saving
" - except Exception as e: - conn.rollback() - logging.error(f"Unexpected error in save_conversation: {e}") - return f"Unexpected error: {str(e)}", "Unexpected error occurred
" - - def delete_conversation(selected, conversation_mapping): - if not selected or selected not in conversation_mapping: - return "Please select a conversation before deleting.", "No changes made
", gr.update(choices=[]) - - conversation_id = conversation_mapping[selected] - - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - # Delete messages associated with the conversation - cursor.execute("DELETE FROM ChatMessages WHERE conversation_id = ?", (conversation_id,)) - - # Delete the conversation itself - cursor.execute("DELETE FROM ChatConversations WHERE id = ?", (conversation_id,)) - - conn.commit() - - # Update the conversation list - remaining_conversations = [choice for choice in conversation_mapping.keys() if choice != selected] - updated_mapping = {choice: conversation_mapping[choice] for choice in remaining_conversations} - - return "Conversation deleted successfully.", "Conversation deleted
", gr.update(choices=remaining_conversations) - except sqlite3.Error as e: - conn.rollback() - logging.error(f"Database error in delete_conversation: {e}") - return f"Error deleting conversation: {str(e)}", "Error occurred while deleting
", gr.update() - except Exception as e: - conn.rollback() - logging.error(f"Unexpected error in delete_conversation: {e}") - return f"Unexpected error: {str(e)}", "Unexpected error occurred
", gr.update() - - def parse_formatted_content(formatted_content): - lines = formatted_content.split('\n') - conversation_id = int(lines[0].split(': ')[1]) - timestamp = lines[1].split(': ')[1] - history = [] - current_role = None - current_content = None - for line in lines[3:]: - if line.startswith("Role: "): - if current_role is not None: - history.append({"role": current_role, "content": ["", current_content]}) - current_role = line.split(': ')[1] - elif line.startswith("Content: "): - current_content = line.split(': ', 1)[1] - if current_role is not None: - history.append({"role": current_role, "content": ["", current_content]}) - return json.dumps({ - "conversation_id": conversation_id, - "timestamp": timestamp, - "history": history - }, indent=2) - - search_button.click( - search_conversations, - inputs=[search_query], - outputs=[conversation_list, conversation_mapping] - ) - - conversation_list.change( - load_conversations, - inputs=[conversation_list, conversation_mapping], - outputs=[chat_content, chat_preview] - ) - - save_button.click( - save_conversation, - inputs=[conversation_list, conversation_mapping, chat_content], - outputs=[result_message, chat_preview] - ) - - delete_button.click( - delete_conversation, - inputs=[conversation_list, conversation_mapping], - outputs=[result_message, chat_preview, conversation_list] - ) - - return search_query, search_button, conversation_list, conversation_mapping, chat_content, save_button, delete_button, result_message, chat_preview - - - # Mock function to simulate LLM processing def process_with_llm(workflow, context, prompt, api_endpoint, api_key): api_key_snippet = api_key[:5] + "..." if api_key else "Not provided" return f"LLM output using {api_endpoint} (API Key: {api_key_snippet}) for {workflow} with context: {context[:30]}... and prompt: {prompt[:30]}..." 
- # # End of Chat_ui.py ####################################################################################################################### \ No newline at end of file diff --git a/App_Function_Libraries/Gradio_UI/Embeddings_tab.py b/App_Function_Libraries/Gradio_UI/Embeddings_tab.py index a12e10320..fb6482ee2 100644 --- a/App_Function_Libraries/Gradio_UI/Embeddings_tab.py +++ b/App_Function_Libraries/Gradio_UI/Embeddings_tab.py @@ -4,6 +4,7 @@ # Imports import json import logging +import os # # External Imports import gradio as gr @@ -11,26 +12,58 @@ from tqdm import tqdm # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_all_content_from_database +from App_Function_Libraries.DB.DB_Manager import get_all_content_from_database, get_all_conversations, \ + get_conversation_text, get_note_by_id +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_all_notes from App_Function_Libraries.RAG.ChromaDB_Library import chroma_client, \ store_in_chroma, situate_context from App_Function_Libraries.RAG.Embeddings_Create import create_embedding, create_embeddings_batch from App_Function_Libraries.Chunk_Lib import improved_chunking_process, chunk_for_embedding +from App_Function_Libraries.Utils.Utils import load_and_log_configs + + # ######################################################################################################################## # # Functions: def create_embeddings_tab(): + # Load configuration first + config = load_and_log_configs() + if not config: + raise ValueError("Could not load configuration") + + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + rag_qa_db_path = os.path.join(os.path.dirname(media_db_path), "rag_qa.db") + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "chatDB.db") + chroma_db_path = db_config['chroma_db_path'] + with gr.TabItem("Create Embeddings", visible=True): gr.Markdown("# Create Embeddings for All Content") with gr.Row(): with gr.Column(): + # Database selection at the top + database_selection = gr.Radio( + choices=["Media DB", "RAG Chat", "Character Chat"], + label="Select Content Source", + value="Media DB", + info="Choose which database to create embeddings from" + ) + + # Add database path display + current_db_path = gr.Textbox( + label="Current Database Path", + value=media_db_path, + interactive=False + ) + embedding_provider = gr.Radio( choices=["huggingface", "local", "openai"], label="Select Embedding Provider", - value="huggingface" + value=config['embedding_config']['embedding_provider'] or "huggingface" ) gr.Markdown("Note: Local provider requires a running Llama.cpp/llamafile server.") gr.Markdown("OpenAI provider requires a valid API key.") @@ -65,22 +98,24 @@ def create_embeddings_tab(): embedding_api_url = gr.Textbox( label="API URL (for local provider)", - value="http://localhost:8080/embedding", + value=config['embedding_config']['embedding_api_url'], visible=False ) - # Add chunking options + # Add chunking options with config defaults chunking_method = gr.Dropdown( choices=["words", "sentences", "paragraphs", "tokens", "semantic"], label="Chunking Method", value="words" ) max_chunk_size = gr.Slider( - minimum=1, maximum=8000, step=1, value=500, + minimum=1, maximum=8000, step=1, + value=config['embedding_config']['chunk_size'], label="Max Chunk Size" ) chunk_overlap = gr.Slider( - minimum=0, maximum=4000, step=1, value=200, + minimum=0, maximum=4000, step=1, + value=config['embedding_config']['overlap'], label="Chunk 
Overlap" ) adaptive_chunking = gr.Checkbox( @@ -92,6 +127,7 @@ def create_embeddings_tab(): with gr.Column(): status_output = gr.Textbox(label="Status", lines=10) + progress = gr.Progress() def update_provider_options(provider): if provider == "huggingface": @@ -107,23 +143,54 @@ def update_huggingface_options(model): else: return gr.update(visible=False) - embedding_provider.change( - fn=update_provider_options, - inputs=[embedding_provider], - outputs=[huggingface_model, openai_model, custom_embedding_model, embedding_api_url] - ) - - huggingface_model.change( - fn=update_huggingface_options, - inputs=[huggingface_model], - outputs=[custom_embedding_model] - ) + def update_database_path(database_type): + if database_type == "Media DB": + return media_db_path + elif database_type == "RAG Chat": + return rag_qa_db_path + else: # Character Chat + return character_chat_db_path - def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_url, method, max_size, overlap, adaptive): + def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_url, method, + max_size, overlap, adaptive, database_type, progress=gr.Progress()): try: - all_content = get_all_content_from_database() + # Initialize content based on database selection + if database_type == "Media DB": + all_content = get_all_content_from_database() + content_type = "media" + elif database_type == "RAG Chat": + all_content = [] + page = 1 + while True: + conversations, total_pages, _ = get_all_conversations(page=page) + if not conversations: + break + all_content.extend([{ + 'id': conv['conversation_id'], + 'content': get_conversation_text(conv['conversation_id']), + 'title': conv['title'], + 'type': 'conversation' + } for conv in conversations]) + progress(page / total_pages, desc=f"Loading conversations... Page {page}/{total_pages}") + page += 1 + else: # Character Chat + all_content = [] + page = 1 + while True: + notes, total_pages, _ = get_all_notes(page=page) + if not notes: + break + all_content.extend([{ + 'id': note['id'], + 'content': f"{note['title']}\n\n{note['content']}", + 'conversation_id': note['conversation_id'], + 'type': 'note' + } for note in notes]) + progress(page / total_pages, desc=f"Loading notes... Page {page}/{total_pages}") + page += 1 + if not all_content: - return "No content found in the database." + return "No content found in the selected database." 
chunk_options = { 'method': method, @@ -132,7 +199,7 @@ def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_ur 'adaptive': adaptive } - collection_name = "all_content_embeddings" + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" collection = chroma_client.get_or_create_collection(name=collection_name) # Determine the model to use @@ -141,55 +208,113 @@ def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_ur elif provider == "openai": model = openai_model else: - model = custom_model + model = api_url + + total_items = len(all_content) + for idx, item in enumerate(all_content): + progress((idx + 1) / total_items, desc=f"Processing item {idx + 1} of {total_items}") - for item in all_content: - media_id = item['id'] + content_id = item['id'] text = item['content'] chunks = improved_chunking_process(text, chunk_options) - for i, chunk in enumerate(chunks): + for chunk_idx, chunk in enumerate(chunks): chunk_text = chunk['text'] - chunk_id = f"doc_{media_id}_chunk_{i}" - - existing = collection.get(ids=[chunk_id]) - if existing['ids']: + chunk_id = f"{database_type.lower()}_{content_id}_chunk_{chunk_idx}" + + try: + embedding = create_embedding(chunk_text, provider, model, api_url) + metadata = { + 'content_id': str(content_id), + 'chunk_index': int(chunk_idx), + 'total_chunks': int(len(chunks)), + 'chunking_method': method, + 'max_chunk_size': int(max_size), + 'chunk_overlap': int(overlap), + 'adaptive_chunking': bool(adaptive), + 'embedding_model': model, + 'embedding_provider': provider, + 'content_type': item.get('type', 'media'), + 'conversation_id': item.get('conversation_id'), + **{k: (int(v) if isinstance(v, str) and v.isdigit() else v) + for k, v in chunk['metadata'].items()} + } + store_in_chroma(collection_name, [chunk_text], [embedding], [chunk_id], [metadata]) + + except Exception as e: + logging.error(f"Error processing chunk {chunk_id}: {str(e)}") continue - embedding = create_embedding(chunk_text, provider, model, api_url) - metadata = { - "media_id": str(media_id), - "chunk_index": i, - "total_chunks": len(chunks), - "chunking_method": method, - "max_chunk_size": max_size, - "chunk_overlap": overlap, - "adaptive_chunking": adaptive, - "embedding_model": model, - "embedding_provider": provider, - **chunk['metadata'] - } - store_in_chroma(collection_name, [chunk_text], [embedding], [chunk_id], [metadata]) - - return "Embeddings created and stored successfully for all content." + return f"Embeddings created and stored successfully for all {database_type} content." 
except Exception as e: logging.error(f"Error during embedding creation: {str(e)}") return f"Error: {str(e)}" + # Event handlers + embedding_provider.change( + fn=update_provider_options, + inputs=[embedding_provider], + outputs=[huggingface_model, openai_model, custom_embedding_model, embedding_api_url] + ) + + huggingface_model.change( + fn=update_huggingface_options, + inputs=[huggingface_model], + outputs=[custom_embedding_model] + ) + + database_selection.change( + fn=update_database_path, + inputs=[database_selection], + outputs=[current_db_path] + ) + create_button.click( fn=create_all_embeddings, - inputs=[embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, - chunking_method, max_chunk_size, chunk_overlap, adaptive_chunking], + inputs=[ + embedding_provider, huggingface_model, openai_model, custom_embedding_model, + embedding_api_url, chunking_method, max_chunk_size, chunk_overlap, + adaptive_chunking, database_selection + ], outputs=status_output ) def create_view_embeddings_tab(): + # Load configuration first + config = load_and_log_configs() + if not config: + raise ValueError("Could not load configuration") + + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + rag_qa_db_path = os.path.join(os.path.dirname(media_db_path), "rag_chat.db") + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "character_chat.db") + chroma_db_path = db_config['chroma_db_path'] + with gr.TabItem("View/Update Embeddings", visible=True): gr.Markdown("# View and Update Embeddings") - item_mapping = gr.State({}) + # Initialize item_mapping as a Gradio State + + with gr.Row(): with gr.Column(): + # Add database selection + database_selection = gr.Radio( + choices=["Media DB", "RAG Chat", "Character Chat"], + label="Select Content Source", + value="Media DB", + info="Choose which database to view embeddings from" + ) + + # Add database path display + current_db_path = gr.Textbox( + label="Current Database Path", + value=media_db_path, + interactive=False + ) + item_dropdown = gr.Dropdown(label="Select Item", choices=[], interactive=True) refresh_button = gr.Button("Refresh Item List") embedding_status = gr.Textbox(label="Embedding Status", interactive=False) @@ -236,9 +361,10 @@ def create_view_embeddings_tab(): embedding_api_url = gr.Textbox( label="API URL (for local provider)", - value="http://localhost:8080/embedding", + value=config['embedding_config']['embedding_api_url'], visible=False ) + chunking_method = gr.Dropdown( choices=["words", "sentences", "paragraphs", "tokens", "semantic"], label="Chunking Method", @@ -267,15 +393,45 @@ def create_view_embeddings_tab(): ) contextual_api_key = gr.Textbox(label="API Key", lines=1) - def get_items_with_embedding_status(): + item_mapping = gr.State(value={}) + + def update_database_path(database_type): + if database_type == "Media DB": + return media_db_path + elif database_type == "RAG Chat": + return rag_qa_db_path + else: # Character Chat + return character_chat_db_path + + def get_items_with_embedding_status(database_type): try: - items = get_all_content_from_database() - collection = chroma_client.get_or_create_collection(name="all_content_embeddings") + # Get items based on database selection + if database_type == "Media DB": + items = get_all_content_from_database() + elif database_type == "RAG Chat": + conversations, _, _ = get_all_conversations(page=1) + items = [{ + 'id': conv['conversation_id'], + 'title': conv['title'], + 
'type': 'conversation' + } for conv in conversations] + else: # Character Chat + notes, _, _ = get_all_notes(page=1) + items = [{ + 'id': note['id'], + 'title': note['title'], + 'type': 'note' + } for note in notes] + + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" + collection = chroma_client.get_or_create_collection(name=collection_name) + choices = [] new_item_mapping = {} for item in items: try: - result = collection.get(ids=[f"doc_{item['id']}_chunk_0"]) + chunk_id = f"{database_type.lower()}_{item['id']}_chunk_0" + result = collection.get(ids=[chunk_id]) embedding_exists = result is not None and result.get('ids') and len(result['ids']) > 0 status = "Embedding exists" if embedding_exists else "No embedding" except Exception as e: @@ -303,40 +459,62 @@ def update_huggingface_options(model): else: return gr.update(visible=False) - def check_embedding_status(selected_item, item_mapping): + def check_embedding_status(selected_item, database_type, item_mapping): if not selected_item: return "Please select an item", "", "" + if item_mapping is None: + # If mapping is None, try to refresh it + try: + _, item_mapping = get_items_with_embedding_status(database_type) + except Exception as e: + return f"Error initializing item mapping: {str(e)}", "", "" + try: item_id = item_mapping.get(selected_item) if item_id is None: return f"Invalid item selected: {selected_item}", "", "" item_title = selected_item.rsplit(' (', 1)[0] - collection = chroma_client.get_or_create_collection(name="all_content_embeddings") + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" + collection = chroma_client.get_or_create_collection(name=collection_name) + chunk_id = f"{database_type.lower()}_{item_id}_chunk_0" + + try: + result = collection.get(ids=[chunk_id], include=["embeddings", "metadatas"]) + except Exception as e: + logging.error(f"ChromaDB get error: {str(e)}") + return f"Error retrieving embedding for '{item_title}': {str(e)}", "", "" - result = collection.get(ids=[f"doc_{item_id}_chunk_0"], include=["embeddings", "metadatas"]) - logging.info(f"ChromaDB result for item '{item_title}' (ID: {item_id}): {result}") + # Check if result exists and has the expected structure + if not result or not isinstance(result, dict): + return f"No embedding found for item '{item_title}' (ID: {item_id})", "", "" - if not result['ids']: + # Check if we have any results + if not result.get('ids') or len(result['ids']) == 0: return f"No embedding found for item '{item_title}' (ID: {item_id})", "", "" - if not result['embeddings'] or not result['embeddings'][0]: + # Check if embeddings exist + if not result.get('embeddings') or not result['embeddings'][0]: return f"Embedding data missing for item '{item_title}' (ID: {item_id})", "", "" embedding = result['embeddings'][0] - metadata = result['metadatas'][0] if result['metadatas'] else {} + metadata = result.get('metadatas', [{}])[0] if result.get('metadatas') else {} embedding_preview = str(embedding[:50]) status = f"Embedding exists for item '{item_title}' (ID: {item_id})" return status, f"First 50 elements of embedding:\n{embedding_preview}", json.dumps(metadata, indent=2) except Exception as e: - logging.error(f"Error in check_embedding_status: {str(e)}") + logging.error(f"Error in check_embedding_status: {str(e)}", exc_info=True) return f"Error processing item: {selected_item}. 
Details: {str(e)}", "", "" - def create_new_embedding_for_item(selected_item, provider, hf_model, openai_model, custom_model, api_url, - method, max_size, overlap, adaptive, - item_mapping, use_contextual, contextual_api_choice=None): + def refresh_and_update(database_type): + choices_update, new_mapping = get_items_with_embedding_status(database_type) + return choices_update, new_mapping + + def create_new_embedding_for_item(selected_item, database_type, provider, hf_model, openai_model, + custom_model, api_url, method, max_size, overlap, adaptive, + item_mapping, use_contextual, contextual_api_choice=None): if not selected_item: return "Please select an item", "", "" @@ -345,8 +523,26 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode if item_id is None: return f"Invalid item selected: {selected_item}", "", "" - items = get_all_content_from_database() - item = next((item for item in items if item['id'] == item_id), None) + # Get item content based on database type + if database_type == "Media DB": + items = get_all_content_from_database() + item = next((item for item in items if item['id'] == item_id), None) + elif database_type == "RAG Chat": + item = { + 'id': item_id, + 'content': get_conversation_text(item_id), + 'title': selected_item.rsplit(' (', 1)[0], + 'type': 'conversation' + } + else: # Character Chat + note = get_note_by_id(item_id) + item = { + 'id': item_id, + 'content': f"{note['title']}\n\n{note['content']}", + 'title': note['title'], + 'type': 'note' + } + if not item: return f"Item not found: {item_id}", "", "" @@ -359,11 +555,11 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode logging.info(f"Chunking content for item: {item['title']} (ID: {item_id})") chunks = chunk_for_embedding(item['content'], item['title'], chunk_options) - collection_name = "all_content_embeddings" + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" collection = chroma_client.get_or_create_collection(name=collection_name) # Delete existing embeddings for this item - existing_ids = [f"doc_{item_id}_chunk_{i}" for i in range(len(chunks))] + existing_ids = [f"{database_type.lower()}_{item_id}_chunk_{i}" for i in range(len(chunks))] collection.delete(ids=existing_ids) logging.info(f"Deleted {len(existing_ids)} existing embeddings for item {item_id}") @@ -381,7 +577,7 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode contextualized_text = chunk_text context = None - chunk_id = f"doc_{item_id}_chunk_{i}" + chunk_id = f"{database_type.lower()}_{item_id}_chunk_{i}" # Determine the model to use if provider == "huggingface": @@ -392,7 +588,7 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode model = custom_model metadata = { - "media_id": str(item_id), + "content_id": str(item_id), "chunk_index": i, "total_chunks": len(chunks), "chunking_method": method, @@ -441,15 +637,25 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode logging.error(f"Error in create_new_embedding_for_item: {str(e)}", exc_info=True) return f"Error creating embedding: {str(e)}", "", "" + # Wire up all the event handlers + database_selection.change( + update_database_path, + inputs=[database_selection], + outputs=[current_db_path] + ) + refresh_button.click( get_items_with_embedding_status, + inputs=[database_selection], outputs=[item_dropdown, item_mapping] ) + item_dropdown.change( check_embedding_status, - inputs=[item_dropdown, item_mapping], + 
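create_new_embedding_for_item now assembles one uniform item dict (id, content, title, type) regardless of the selected source, branching to get_all_content_from_database, get_conversation_text, or get_note_by_id. A compact way to express the same dispatch; the loaders below are stand-ins, not the project's functions:

# Stand-in loaders -- the real tab calls the DB helpers named in the note above.
def load_media_item(item_id):
    return {'id': item_id, 'content': 'media transcript...', 'title': f'Media {item_id}', 'type': 'media'}

def load_conversation_item(item_id):
    return {'id': item_id, 'content': 'conversation text...', 'title': f'Conversation {item_id}', 'type': 'conversation'}

def load_note_item(item_id):
    # Mirrors the diff: the note's title is prepended to its content before chunking.
    title, body = 'Note title', 'note body...'
    return {'id': item_id, 'content': f"{title}\n\n{body}", 'title': title, 'type': 'note'}

LOADERS = {
    "Media DB": load_media_item,
    "RAG Chat": load_conversation_item,
    "Character Chat": load_note_item,   # the else-branch in the diff loads notes
}

def fetch_item(database_type, item_id):
    loader = LOADERS.get(database_type)
    if loader is None:
        raise ValueError(f"Unknown content source: {database_type}")
    return loader(item_id)

print(fetch_item("RAG Chat", 7)['type'])   # conversation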
inputs=[item_dropdown, database_selection, item_mapping], outputs=[embedding_status, embedding_preview, embedding_metadata] ) + create_new_embedding_button.click( create_new_embedding_for_item, inputs=[item_dropdown, embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, @@ -469,9 +675,10 @@ def create_new_embedding_for_item(selected_item, provider, hf_model, openai_mode ) return (item_dropdown, refresh_button, embedding_status, embedding_preview, embedding_metadata, - create_new_embedding_button, embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, - chunking_method, max_chunk_size, chunk_overlap, adaptive_chunking, - use_contextual_embeddings, contextual_api_choice, contextual_api_key) + create_new_embedding_button, embedding_provider, huggingface_model, openai_model, + custom_embedding_model, embedding_api_url, chunking_method, max_chunk_size, + chunk_overlap, adaptive_chunking, use_contextual_embeddings, + contextual_api_choice, contextual_api_key) def create_purge_embeddings_tab(): diff --git a/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py b/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py index fbf5505ac..50942fa91 100644 --- a/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py +++ b/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py @@ -7,7 +7,7 @@ # External Imports import gradio as gr -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_user_prompt # # Local Imports @@ -37,32 +37,52 @@ def create_summarize_explain_tab(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + with gr.TabItem("Analyze Text", visible=True): gr.Markdown("# Analyze / Explain / Summarize Text without ingesting it into the DB") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): with gr.Row(): - text_to_work_input = gr.Textbox(label="Text to be Explained or Summarized", - placeholder="Enter the text you want explained or summarized here", - lines=20) + text_to_work_input = gr.Textbox( + label="Text to be Explained or Summarized", + placeholder="Enter the text you want explained or summarized here", + lines=20 + ) with gr.Row(): explanation_checkbox = gr.Checkbox(label="Explain Text", value=True) summarization_checkbox = gr.Checkbox(label="Summarize Text", value=True) - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - 
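Several tabs in this patch replace the one-shot load_preset_prompts() dropdown with paginated list_prompts(page, per_page) calls driven by Previous/Next buttons; the clamping logic is the same everywhere. A standalone sketch over toy data with an assumed page size of 10 -- note that the initial fetch and the pager should use the same per_page, or the reported total_pages can drift:

PROMPTS = [f"Prompt {i}" for i in range(1, 36)]   # toy stand-in for the prompts DB
PER_PAGE = 10

def list_prompts_stub(page, per_page=PER_PAGE):
    total_pages = max(1, -(-len(PROMPTS) // per_page))   # ceiling division
    page = max(1, min(total_pages, page))                # clamp, as update_prompt_page does
    start = (page - 1) * per_page
    return PROMPTS[start:start + per_page], total_pages, page

def turn_page(direction, current_page):
    prompts, total_pages, new_page = list_prompts_stub(current_page + direction)
    return prompts, f"Page {new_page} of {total_pages}", new_page

print(turn_page(+1, 1)[1])   # Page 2 of 4
print(turn_page(-1, 1)[1])   # Page 1 of 4 (clamped at the first page)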
placeholder="Enter custom prompt here", - lines=3, - visible=False) + custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=10, + visible=False + ) with gr.Row(): system_prompt_input = gr.Textbox(label="System Prompt", value="""No valid mindmap content provided.
" + + html = "" + + colors = ['#e6f3ff', '#f0f7ff', '#f5f5f5', '#fff0f0', '#f0fff0'] + + def create_node_html(node, level): + bg_color = colors[(level - 1) % len(colors)] + if node['children']: + children_html = ''.join(create_node_html(child, level + 1) for child in node['children']) + return f""" +Title | Author |
---|---|
{html.escape(title)} | {html.escape(author)} |
Error fetching prompts: {e}
", "Error", 0, [] + + # Function to update page content + def update_page(page, entries_per_page): + results, pagination, total_pages, prompt_choices = view_database(page, entries_per_page) + page = int(page) + next_disabled = page >= total_pages + prev_disabled = page <= 1 + return ( + results, + pagination, + page, + gr.update(visible=True, interactive=not prev_disabled), # previous_page_button + gr.update(visible=True, interactive=not next_disabled), # next_page_button + gr.update(choices=prompt_choices) + ) + + # Function to go to the next page + def go_to_next_page(current_page, entries_per_page): + next_page = int(current_page) + 1 + return update_page(next_page, entries_per_page) + + # Function to go to the previous page + def go_to_previous_page(current_page, entries_per_page): + previous_page = max(1, int(current_page) - 1) + return update_page(previous_page, entries_per_page) + + # Function to display selected prompt details + def display_selected_prompt(prompt_name): + details = fetch_prompt_details(prompt_name) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + # Handle None values by converting them to empty strings + description = description or "" + system_prompt = system_prompt or "" + user_prompt = user_prompt or "" + author = author or "Unknown" + keywords = keywords or "" + + html_content = f""" +Description: {html.escape(description)}
+{html.escape(system_prompt)}+
{html.escape(user_prompt)}+
Keywords: {html.escape(keywords)}
+Prompt not found.
" + + # Event handlers + view_button.click( + fn=update_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + next_page_button.click( + fn=go_to_next_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + previous_page_button.click( + fn=go_to_previous_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + prompt_selector.change( + fn=display_selected_prompt, + inputs=[prompt_selector], + outputs=[selected_prompt_display] + ) + + + +def create_prompts_export_tab(): + """Creates a tab for exporting prompts database content with multiple format options""" + with gr.TabItem("Export Prompts", visible=True): + gr.Markdown("# Export Prompts Database Content") + + with gr.Row(): + with gr.Column(): + export_type = gr.Radio( + choices=["All Prompts", "Prompts by Keyword"], + label="Export Type", + value="All Prompts" + ) + + # Keyword selection for filtered export + with gr.Column(visible=False) as keyword_col: + keyword_input = gr.Textbox( + label="Enter Keywords (comma-separated)", + placeholder="Enter keywords to filter prompts..." + ) + + # Export format selection + export_format = gr.Radio( + choices=["CSV", "Markdown (ZIP)"], + label="Export Format", + value="CSV" + ) + + # Export options + include_options = gr.CheckboxGroup( + choices=[ + "Include System Prompts", + "Include User Prompts", + "Include Details", + "Include Author", + "Include Keywords" + ], + label="Export Options", + value=["Include Keywords", "Include Author"] + ) + + # Markdown-specific options (only visible when Markdown is selected) + with gr.Column(visible=False) as markdown_options_col: + markdown_template = gr.Radio( + choices=[ + "Basic Template", + "Detailed Template", + "Custom Template" + ], + label="Markdown Template", + value="Basic Template" + ) + custom_template = gr.Textbox( + label="Custom Template", + placeholder="Use {title}, {author}, {details}, {system}, {user}, {keywords} as placeholders", + visible=False + ) + + export_button = gr.Button("Export Prompts") + + with gr.Column(): + export_status = gr.Textbox(label="Export Status", interactive=False) + export_file = gr.File(label="Download Export") + + def update_ui_visibility(export_type, format_choice, template_choice): + """Update UI elements visibility based on selections""" + show_keywords = export_type == "Prompts by Keyword" + show_markdown_options = format_choice == "Markdown (ZIP)" + show_custom_template = template_choice == "Custom Template" and show_markdown_options + + return [ + gr.update(visible=show_keywords), # keyword_col + gr.update(visible=show_markdown_options), # markdown_options_col + gr.update(visible=show_custom_template) # custom_template + ] + + def handle_export(export_type, keywords, export_format, options, markdown_template, custom_template): + """Handle the export process based on selected options""" + try: + # Parse options + include_system = "Include System Prompts" in options + include_user = "Include User Prompts" in options + include_details = "Include Details" in options + include_author = "Include Author" in options + include_keywords = "Include Keywords" in options + + # Handle keyword filtering + keyword_list = None + if export_type == "Prompts by Keyword" and keywords: + 
keyword_list = [k.strip() for k in keywords.split(",") if k.strip()] + + # Get the appropriate template + template = None + if export_format == "Markdown (ZIP)": + if markdown_template == "Custom Template": + template = custom_template + else: + template = markdown_template + + # Perform export + from App_Function_Libraries.DB.Prompts_DB import export_prompts + status, file_path = export_prompts( + export_format=export_format.split()[0].lower(), # 'csv' or 'markdown' + filter_keywords=keyword_list, + include_system=include_system, + include_user=include_user, + include_details=include_details, + include_author=include_author, + include_keywords=include_keywords, + markdown_template=template + ) + + return status, file_path + + except Exception as e: + error_msg = f"Export failed: {str(e)}" + logging.error(error_msg) + return error_msg, None + + # Event handlers + export_type.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_format.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + markdown_template.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_button.click( + fn=handle_export, + inputs=[ + export_type, + keyword_input, + export_format, + include_options, + markdown_template, + custom_template + ], + outputs=[export_status, export_file] + ) + +# +# End of Prompts_tab.py +#################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py index 76d316e0e..0c88aaeb3 100644 --- a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py @@ -6,6 +6,7 @@ import logging import json import os +import re from datetime import datetime # # External Imports @@ -14,27 +15,21 @@ # # Local Imports from App_Function_Libraries.Books.Book_Ingestion_Lib import read_epub -from App_Function_Libraries.DB.DB_Manager import DatabaseError, get_paginated_files, add_media_with_keywords -from App_Function_Libraries.DB.RAG_QA_Chat_DB import ( - save_notes, - add_keywords_to_note, - start_new_conversation, - save_message, - search_conversations_by_keywords, - load_chat_history, - get_all_conversations, - get_note_by_id, - get_notes_by_keywords, - get_notes_by_keyword_collection, - update_note, - clear_keywords_from_note, get_notes, get_keywords_for_note, delete_conversation, delete_note, execute_query, - add_keywords_to_conversation, fetch_all_notes, fetch_all_conversations, fetch_conversations_by_ids, - fetch_notes_by_ids, -) +from App_Function_Libraries.DB.Character_Chat_DB import search_character_chat, search_character_cards +from App_Function_Libraries.DB.DB_Manager import DatabaseError, get_paginated_files, add_media_with_keywords, \ + get_all_conversations, get_note_by_id, get_notes_by_keywords, start_new_conversation, update_note, save_notes, \ + clear_keywords_from_note, add_keywords_to_note, load_chat_history, save_message, add_keywords_to_conversation, \ + get_keywords_for_note, delete_note, search_conversations_by_keywords, get_conversation_title, delete_conversation, \ + 
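handle_export above is mostly translation work: the checkbox-group selections become boolean keyword arguments and the free-text keyword box becomes a cleaned list before export_prompts (imported from App_Function_Libraries.DB.Prompts_DB) is called. The translation step on its own, as a hedged sketch:

def parse_export_request(export_type, keywords, options):
    # Checkbox labels -> boolean flags, mirroring handle_export.
    flags = {
        "include_system": "Include System Prompts" in options,
        "include_user": "Include User Prompts" in options,
        "include_details": "Include Details" in options,
        "include_author": "Include Author" in options,
        "include_keywords": "Include Keywords" in options,
    }
    keyword_list = None
    if export_type == "Prompts by Keyword" and keywords:
        keyword_list = [k.strip() for k in keywords.split(",") if k.strip()]
    return flags, keyword_list

flags, kws = parse_export_request("Prompts by Keyword", " rag, summarization, ", ["Include Author"])
print(flags["include_author"], kws)   # True ['rag', 'summarization']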
update_conversation_title, fetch_all_conversations, fetch_all_notes, fetch_conversations_by_ids, fetch_notes_by_ids, \ + search_media_db, search_notes_titles, list_prompts +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_notes, delete_messages_in_conversation, search_rag_notes, \ + search_rag_chat, get_conversation_rating, set_conversation_rating +from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_user_prompt from App_Function_Libraries.PDF.PDF_Ingestion_Lib import extract_text_and_format_from_pdf from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer, enhanced_rag_pipeline from App_Function_Libraries.RAG.RAG_QA_Chat import search_database, rag_qa_chat -from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name, \ + load_comprehensive_config # @@ -60,18 +55,53 @@ def create_rag_qa_chat_tab(): "page": 1, "context_source": "Entire Media Database", "conversation_messages": [], + "conversation_id": None }) note_state = gr.State({"note_id": None}) + def auto_save_conversation(message, response, state_value, auto_save_enabled): + """Automatically save the conversation if auto-save is enabled""" + try: + if not auto_save_enabled: + return state_value + + conversation_id = state_value.get("conversation_id") + if not conversation_id: + # Create new conversation with default title + title = "Auto-saved Conversation " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_id = start_new_conversation(title=title) + state_value = state_value.copy() + state_value["conversation_id"] = conversation_id + + # Save the messages + save_message(conversation_id, "user", message) + save_message(conversation_id, "assistant", response) + + return state_value + except Exception as e: + logging.error(f"Error in auto-save: {str(e)}") + return state_value + # Update the conversation list function def update_conversation_list(): conversations, total_pages, total_count = get_all_conversations() - choices = [f"{title} (ID: {conversation_id})" for conversation_id, title in conversations] + choices = [ + f"{conversation['title']} (ID: {conversation['conversation_id']}) - Rating: {conversation['rating'] or 'Not Rated'}" + for conversation in conversations + ] return choices with gr.Row(): with gr.Column(scale=1): + # FIXME - Offer the user to search 2+ databases at once + database_types = ["Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards"] + db_choice = gr.CheckboxGroup( + label="Select Database(s)", + choices=database_types, + value=["Media DB"], + interactive=True + ) context_source = gr.Radio( ["All Files in the Database", "Search Database", "Upload File"], label="Context Source", @@ -84,19 +114,52 @@ def update_conversation_list(): next_page_btn = gr.Button("Next Page") page_info = gr.HTML("Page 1") top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True) - keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True) + keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", value="rag_qa_default_keyword" ,visible=True) use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True) - # with gr.Row(): - # page_number = gr.Number(value=1, label="Page", precision=0) - 
# page_size = gr.Number(value=20, label="Items per page", precision=0) - # total_pages = gr.Number(label="Total Pages", interactive=False) + config = load_comprehensive_config() + auto_save_value = config.getboolean('auto-save', 'save_character_chats', fallback=False) + auto_save_checkbox = gr.Checkbox( + label="Save chats automatically", + value=auto_save_value, + info="When enabled, conversations will be saved automatically after each message" + ) + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + + preset_prompt_checkbox = gr.Checkbox( + label="View Custom Prompts(have to copy/paste them)", + value=False, + visible=True + ) + + with gr.Row(visible=False) as preset_prompt_controls: + prev_prompt_page = gr.Button("Previous") + current_prompt_page_text = gr.Text(f"Page {current_page} of {total_pages}") + next_prompt_page = gr.Button("Next") + current_prompt_page_state = gr.State(value=1) + + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=initial_prompts, + visible=False + ) + user_prompt = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False + ) + + system_prompt_input = gr.Textbox( + label="System Prompt", + lines=3, + visible=False + ) search_query = gr.Textbox(label="Search Query", visible=False) search_button = gr.Button("Search", visible=False) search_results = gr.Dropdown(label="Search Results", choices=[], visible=False) - # FIXME - Add pages for search results handling file_upload = gr.File( label="Upload File", visible=False, @@ -108,14 +171,23 @@ def update_conversation_list(): load_conversation = gr.Dropdown( label="Load Conversation", choices=update_conversation_list() - ) + ) new_conversation = gr.Button("New Conversation") save_conversation_button = gr.Button("Save Conversation") conversation_title = gr.Textbox( - label="Conversation Title", placeholder="Enter a title for the new conversation" + label="Conversation Title", + placeholder="Enter a title for the new conversation" ) keywords = gr.Textbox(label="Keywords (comma-separated)", visible=True) + # Add the rating display and input + rating_display = gr.Markdown(value="", visible=False) + rating_input = gr.Radio( + choices=["1", "2", "3"], + label="Rate this Conversation (1-3 stars)", + visible=False + ) + # Refactored API selection dropdown api_choice = gr.Dropdown( choices=["None"] + [format_api_name(api) for api in global_api_endpoints], @@ -143,6 +215,8 @@ def update_conversation_list(): clear_notes_btn = gr.Button("Clear Current Note text") new_note_btn = gr.Button("New Note") + # FIXME - Change from only keywords to generalized search + search_notes_title = gr.Textbox(label="Search Notes by Title") search_notes_by_keyword = gr.Textbox(label="Search Notes by Keyword") search_notes_button = gr.Button("Search Notes") note_results = gr.Dropdown(label="Notes", choices=[]) @@ -150,8 +224,58 @@ def update_conversation_list(): loading_indicator = gr.HTML("Loading...", visible=False) status_message = gr.HTML() + auto_save_status = gr.HTML() + + # Function Definitions + def update_prompt_page(direction, current_page_val): + new_page = max(1, min(total_pages, current_page_val + direction)) + prompts, _, _ = list_prompts(page=new_page, per_page=10) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return ( + gr.update(value=prompts["user_prompt"], visible=True), + 
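The auto-save default is read from the app config via config.getboolean('auto-save', 'save_character_chats', fallback=False). For reference, that configparser pattern in isolation; the file contents below are made up, only the section and key names follow the diff:

import configparser

config = configparser.ConfigParser()
config.read_string("""
[auto-save]
save_character_chats = true
""")

# getboolean() accepts true/false, yes/no, on/off, 1/0; fallback covers a missing key or section.
auto_save_default = config.getboolean('auto-save', 'save_character_chats', fallback=False)
missing_default = config.getboolean('auto-save', 'some_other_flag', fallback=False)
print(auto_save_default, missing_default)   # True False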
gr.update(value=prompts["system_prompt"], visible=True) + ) + + def toggle_preset_prompt(checkbox_value): + return ( + gr.update(visible=checkbox_value), + gr.update(visible=checkbox_value), + gr.update(visible=False), + gr.update(visible=False) + ) + + prev_prompt_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + next_prompt_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + preset_prompt.change( + update_prompts, + inputs=preset_prompt, + outputs=[user_prompt, system_prompt_input] + ) + + preset_prompt_checkbox.change( + toggle_preset_prompt, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, preset_prompt_controls, user_prompt, system_prompt_input] + ) def update_state(state, **kwargs): new_state = state.copy() @@ -166,18 +290,28 @@ def create_new_note(): outputs=[note_title, notes, note_state] ) - def search_notes(keywords): + def search_notes(search_notes_title, keywords): if keywords: keywords_list = [kw.strip() for kw in keywords.split(',')] notes_data, total_pages, total_count = get_notes_by_keywords(keywords_list) - choices = [f"Note {note_id} ({timestamp})" for note_id, title, content, timestamp in notes_data] - return gr.update(choices=choices) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") + elif search_notes_title: + notes_data, total_pages, total_count = search_notes_titles(search_notes_title) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") else: - return gr.update(choices=[]) + # This will now return all notes, ordered by timestamp + notes_data, total_pages, total_count = search_notes_titles("") + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"All notes ({total_count} total)") search_notes_button.click( search_notes, - inputs=[search_notes_by_keyword], + inputs=[search_notes_title, search_notes_by_keyword], outputs=[note_results] ) @@ -273,83 +407,112 @@ def clear_notes_function(): outputs=[notes, note_state] ) - def update_conversation_list(): - conversations, total_pages, total_count = get_all_conversations() - choices = [f"{title} (ID: {conversation_id})" for conversation_id, title in conversations] - return choices - # Initialize the conversation list load_conversation.choices = update_conversation_list() def load_conversation_history(selected_conversation, state_value): - if selected_conversation: - conversation_id = selected_conversation.split('(ID: ')[1][:-1] + try: + if not selected_conversation: + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) + # Extract conversation ID + match = re.search(r'\(ID: ([0-9a-fA-F\-]+)\)', selected_conversation) + if not match: + logging.error(f"Invalid conversation format: {selected_conversation}") + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) + conversation_id = match.group(1) chat_data, total_pages_val, _ = load_chat_history(conversation_id, 1, 50) - # Convert chat data to 
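search_notes now has a three-way precedence: a keyword filter wins, then a title search, and an empty query falls through to "all notes, ordered by timestamp". The same precedence over toy data, independent of the DB layer:

NOTES = [
    # (id, title, content, timestamp, keywords)
    (1, "RAG pipeline ideas", "...", "2024-11-01", {"rag"}),
    (2, "Grocery list",       "...", "2024-11-02", {"personal"}),
    (3, "RAG eval notes",     "...", "2024-11-03", {"rag", "eval"}),
]

def search_notes_stub(title_query, keywords):
    # Precedence mirrors the tab: keyword filter first, then title search, otherwise everything.
    if keywords:
        wanted = {kw.strip().lower() for kw in keywords.split(',') if kw.strip()}
        hits = [n for n in NOTES if wanted & n[4]]
        label = f"Found {len(hits)} notes"
    elif title_query:
        hits = [n for n in NOTES if title_query.lower() in n[1].lower()]
        label = f"Found {len(hits)} notes"
    else:
        hits = list(NOTES)
        label = f"All notes ({len(hits)} total)"
    choices = [f"Note {nid} - {title} ({ts})" for nid, title, _, ts, _ in hits]
    return choices, label

print(search_notes_stub("", "rag")[1])    # Found 2 notes
print(search_notes_stub("eval", "")[1])   # Found 1 notes
print(search_notes_stub("", "")[1])       # All notes (3 total)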
list of tuples (user_message, assistant_response) + # Update state with valid conversation id + updated_state = state_value.copy() + updated_state["conversation_id"] = conversation_id + updated_state["conversation_messages"] = chat_data + # Format chat history history = [] for role, content in chat_data: if role == 'user': history.append((content, '')) - else: - if history: - history[-1] = (history[-1][0], content) - else: - history.append(('', content)) - # Retrieve notes + elif history: + history[-1] = (history[-1][0], content) + # Fetch and display the conversation rating + rating = get_conversation_rating(conversation_id) + if rating is not None: + rating_text = f"**Current Rating:** {rating} star(s)" + rating_display_update = gr.update(value=rating_text, visible=True) + rating_input_update = gr.update(value=str(rating), visible=True) + else: + rating_display_update = gr.update(value="**Current Rating:** Not Rated", visible=True) + rating_input_update = gr.update(value=None, visible=True) notes_content = get_notes(conversation_id) - updated_state = update_state(state_value, conversation_id=conversation_id, page=1, - conversation_messages=[]) - return history, updated_state, "\n".join(notes_content) - return [], state_value, "" + return history, updated_state, "\n".join( + notes_content) if notes_content else "", rating_display_update, rating_input_update + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) load_conversation.change( load_conversation_history, inputs=[load_conversation, state], - outputs=[chatbot, state, notes] + outputs=[chatbot, state, notes, rating_display, rating_input] ) # Modify save_conversation_function to use gr.update() - def save_conversation_function(conversation_title_text, keywords_text, state_value): + def save_conversation_function(conversation_title_text, keywords_text, rating_value, state_value): conversation_messages = state_value.get("conversation_messages", []) + conversation_id = state_value.get("conversation_id") if not conversation_messages: return gr.update( value="No conversation to save.
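Because the conversation dropdown now packs the title, the ID, and the rating into a single label, load_conversation_history recovers the ID with a regex instead of naive string splitting, and bails out with a log message if the label is malformed. The pattern in isolation (the UUID below is invented):

import re

label = "Project kickoff (ID: 3f2a9c1e-77b4-4d2c-9a1b-0c5d6e7f8a90) - Rating: Not Rated"

match = re.search(r'\(ID: ([0-9a-fA-F\-]+)\)', label)
if match:
    conversation_id = match.group(1)
    print(conversation_id)   # 3f2a9c1e-77b4-4d2c-9a1b-0c5d6e7f8a90
else:
    # Mirrors the diff: log and return empty outputs rather than crash on an odd label.
    print("Invalid conversation format")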
" - ), state_value, gr.update() - # Start a new conversation in the database - new_conversation_id = start_new_conversation( - conversation_title_text if conversation_title_text else "Untitled Conversation" - ) + ), state_value, gr.update(), gr.update(value="", visible=False), gr.update(visible=False) + # Start a new conversation in the database if not existing + if not conversation_id: + conversation_id = start_new_conversation( + conversation_title_text if conversation_title_text else "Untitled Conversation" + ) + else: + # Update the conversation title if it has changed + update_conversation_title(conversation_id, conversation_title_text) # Save the messages for role, content in conversation_messages: - save_message(new_conversation_id, role, content) + save_message(conversation_id, role, content) # Save keywords if provided if keywords_text: - add_keywords_to_conversation(new_conversation_id, [kw.strip() for kw in keywords_text.split(',')]) + add_keywords_to_conversation(conversation_id, [kw.strip() for kw in keywords_text.split(',')]) + # Save the rating if provided + try: + if rating_value: + set_conversation_rating(conversation_id, int(rating_value)) + except ValueError as ve: + logging.error(f"Invalid rating value: {ve}") + return gr.update( + value=f"Invalid rating: {ve}
" + ), state_value, gr.update(), gr.update(value="", visible=False), gr.update(visible=False) + # Update state - updated_state = update_state(state_value, conversation_id=new_conversation_id) + updated_state = update_state(state_value, conversation_id=conversation_id) # Update the conversation list conversation_choices = update_conversation_list() + # Reset rating display and input + rating_display_update = gr.update(value=f"**Current Rating:** {rating_value} star(s)", visible=True) + rating_input_update = gr.update(value=rating_value, visible=True) return gr.update( value="Conversation saved successfully.
" - ), updated_state, gr.update(choices=conversation_choices) + ), updated_state, gr.update(choices=conversation_choices), rating_display_update, rating_input_update save_conversation_button.click( save_conversation_function, - inputs=[conversation_title, keywords, state], - outputs=[status_message, state, load_conversation] + inputs=[conversation_title, keywords, rating_input, state], + outputs=[status_message, state, load_conversation, rating_display, rating_input] ) def start_new_conversation_wrapper(title, state_value): - # Reset the state with no conversation_id - updated_state = update_state(state_value, conversation_id=None, page=1, - conversation_messages=[]) - # Clear the chat history - return [], updated_state + # Reset the state with no conversation_id and empty conversation messages + updated_state = update_state(state_value, conversation_id=None, page=1, conversation_messages=[]) + # Clear the chat history and reset rating components + return [], updated_state, gr.update(value="", visible=False), gr.update(value=None, visible=False) new_conversation.click( start_new_conversation_wrapper, inputs=[conversation_title, state], - outputs=[chatbot, state] + outputs=[chatbot, state, rating_display, rating_input] ) def update_file_list(page): @@ -364,11 +527,12 @@ def prev_page_fn(current_page): return update_file_list(max(1, current_page - 1)) def update_context_source(choice): + # Update visibility based on context source choice return { existing_file: gr.update(visible=choice == "Existing File"), - prev_page_btn: gr.update(visible=choice == "Existing File"), - next_page_btn: gr.update(visible=choice == "Existing File"), - page_info: gr.update(visible=choice == "Existing File"), + prev_page_btn: gr.update(visible=choice == "Search Database"), + next_page_btn: gr.update(visible=choice == "Search Database"), + page_info: gr.update(visible=choice == "Search Database"), search_query: gr.update(visible=choice == "Search Database"), search_button: gr.update(visible=choice == "Search Database"), search_results: gr.update(visible=choice == "Search Database"), @@ -388,17 +552,36 @@ def update_context_source(choice): context_source.change(lambda choice: update_file_list(1) if choice == "Existing File" else (gr.update(), gr.update(), 1), inputs=[context_source], outputs=[existing_file, page_info, file_page]) - def perform_search(query): + def perform_search(query, selected_databases, keywords): try: - results = search_database(query) + results = [] + + # Iterate over selected database types and perform searches accordingly + for database_type in selected_databases: + if database_type == "Media DB": + # FIXME - check for existence of keywords before setting as search field + search_fields = ["title", "content", "keywords"] + results += search_media_db(query, search_fields, keywords, page=1, results_per_page=25) + elif database_type == "RAG Chat": + results += search_rag_chat(query) + elif database_type == "RAG Notes": + results += search_rag_notes(query) + elif database_type == "Character Chat": + results += search_character_chat(query) + elif database_type == "Character Cards": + results += search_character_cards(query) + + # Remove duplicate results if necessary + results = list(set(results)) return gr.update(choices=results) except Exception as e: gr.Error(f"Error performing search: {str(e)}") return gr.update(choices=[]) + # Click Event for the DB Search Button search_button.click( perform_search, - inputs=[search_query], + inputs=[search_query, db_choice, keywords_input], 
outputs=[search_results] ) @@ -420,17 +603,22 @@ def rephrase_question(history, latest_question, api_choice): logging.info(f"Rephrased question: {rephrased_question}") return rephrased_question.strip() - def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload, - convert_to_text, keywords, api_choice, use_query_rewriting, state_value, - keywords_input, top_k_input, use_re_ranking): + # FIXME - RAG DB selection + def rag_qa_chat_wrapper( + message, history, context_source, existing_file, search_results, file_upload, + convert_to_text, keywords, api_choice, use_query_rewriting, state_value, + keywords_input, top_k_input, use_re_ranking, db_choices, auto_save_enabled + ): try: logging.info(f"Starting rag_qa_chat_wrapper with message: {message}") logging.info(f"Context source: {context_source}") logging.info(f"API choice: {api_choice}") logging.info(f"Query rewriting: {'enabled' if use_query_rewriting else 'disabled'}") + logging.info(f"Selected DB Choices: {db_choices}") # Show loading indicator - yield history, "", gr.update(visible=True), state_value + yield history, "", gr.update(visible=True), state_value, gr.update(visible=False), gr.update( + visible=False) conversation_id = state_value.get("conversation_id") conversation_messages = state_value.get("conversation_messages", []) @@ -444,12 +632,12 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ state_value["conversation_messages"] = conversation_messages # Ensure api_choice is a string - api_choice = api_choice.value if isinstance(api_choice, gr.components.Dropdown) else api_choice - logging.info(f"Resolved API choice: {api_choice}") + api_choice_str = api_choice.value if isinstance(api_choice, gr.components.Dropdown) else api_choice + logging.info(f"Resolved API choice: {api_choice_str}") # Only rephrase the question if it's not the first query and query rewriting is enabled if len(history) > 0 and use_query_rewriting: - rephrased_question = rephrase_question(history, message, api_choice) + rephrased_question = rephrase_question(history, message, api_choice_str) logging.info(f"Original question: {message}") logging.info(f"Rephrased question: {rephrased_question}") else: @@ -457,18 +645,20 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ logging.info(f"Using original question: {message}") if context_source == "All Files in the Database": - # Use the enhanced_rag_pipeline to search the entire database - context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, top_k_input, - use_re_ranking) + # Use the enhanced_rag_pipeline to search the selected databases + context = enhanced_rag_pipeline( + rephrased_question, api_choice_str, keywords_input, top_k_input, use_re_ranking, + database_types=db_choices # Pass the list of selected databases + ) logging.info(f"Using enhanced_rag_pipeline for database search") elif context_source == "Search Database": context = f"media_id:{search_results.split('(ID: ')[1][:-1]}" logging.info(f"Using search result with context: {context}") - else: # Upload File + else: + # Upload File logging.info("Processing uploaded file") if file_upload is None: raise ValueError("No file uploaded") - # Process the uploaded file file_path = file_upload.name file_name = os.path.basename(file_path) @@ -481,7 +671,6 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ logging.info("Reading file content") with open(file_path, 'r', encoding='utf-8') as f: content = 
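perform_search now fans the query out to every checked source and concatenates the hits, and enhanced_rag_pipeline likewise receives the selected database_types. A toy version of the fan-out with stand-in search functions; note that the patch deduplicates with list(set(results)), which drops ordering, while the sketch below swaps in dict.fromkeys() to keep it:

def search_media_db_stub(query):  return [f"Media: {query} #1", f"Media: {query} #2"]
def search_rag_chat_stub(query):  return [f"Chat: {query} #1"]
def search_rag_notes_stub(query): return [f"Notes: {query} #1", f"Media: {query} #2"]

SEARCHERS = {
    "Media DB": search_media_db_stub,
    "RAG Chat": search_rag_chat_stub,
    "RAG Notes": search_rag_notes_stub,
}

def perform_search_stub(query, selected_databases):
    results = []
    for database_type in selected_databases:
        results += SEARCHERS[database_type](query)
    # Order-preserving deduplication (the patch uses list(set(results)) instead).
    return list(dict.fromkeys(results))

print(perform_search_stub("llama", ["Media DB", "RAG Notes"]))
# ['Media: llama #1', 'Media: llama #2', 'Notes: llama #1']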
f.read() - logging.info(f"File content length: {len(content)} characters") # Process keywords @@ -503,18 +692,17 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ author='Unknown', ingestion_date=datetime.now().strftime('%Y-%m-%d') ) - logging.info(f"Result from add_media_with_keywords: {result}") if isinstance(result, tuple): media_id, _ = result else: media_id = result - context = f"media_id:{media_id}" logging.info(f"Context for uploaded file: {context}") logging.info("Calling rag_qa_chat function") - new_history, response = rag_qa_chat(rephrased_question, history, context, api_choice) + new_history, response = rag_qa_chat(rephrased_question, history, context, api_choice_str) + # Log first 100 chars of response logging.info(f"Response received from rag_qa_chat: {response[:100]}...") @@ -526,7 +714,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ state_value["conversation_messages"] = conversation_messages # Update the state - state_value["conversation_messages"] = conversation_messages + updated_state = auto_save_conversation(message, response, state_value, auto_save_enabled) + updated_state["conversation_messages"] = conversation_messages # Safely update history if new_history: @@ -534,24 +723,43 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ else: new_history = [(message, response)] + # Get the current rating and update display + conversation_id = updated_state.get("conversation_id") + if conversation_id: + rating = get_conversation_rating(conversation_id) + if rating is not None: + rating_display_update = gr.update(value=f"**Current Rating:** {rating} star(s)", visible=True) + rating_input_update = gr.update(value=str(rating), visible=True) + else: + rating_display_update = gr.update(value="**Current Rating:** Not Rated", visible=True) + rating_input_update = gr.update(value=None, visible=True) + else: + rating_display_update = gr.update(value="", visible=False) + rating_input_update = gr.update(value=None, visible=False) + gr.Info("Response generated successfully") logging.info("rag_qa_chat_wrapper completed successfully") - yield new_history, "", gr.update(visible=False), state_value # Include state_value in outputs + yield new_history, "", gr.update( + visible=False), updated_state, rating_display_update, rating_input_update + except ValueError as e: logging.error(f"Input error in rag_qa_chat_wrapper: {str(e)}") gr.Error(f"Input error: {str(e)}") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) except DatabaseError as e: logging.error(f"Database error in rag_qa_chat_wrapper: {str(e)}") gr.Error(f"Database error: {str(e)}") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) except Exception as e: logging.error(f"Unexpected error in rag_qa_chat_wrapper: {e}", exc_info=True) gr.Error("An unexpected error occurred. 
Please try again later.") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) def clear_chat_history(): - return [], "" + return [], "", gr.update(value="", visible=False), gr.update(value=None, visible=False) submit.click( rag_qa_chat_wrapper, @@ -568,14 +776,17 @@ def clear_chat_history(): use_query_rewriting, state, keywords_input, - top_k_input + top_k_input, + use_re_ranking, + db_choice, + auto_save_checkbox ], - outputs=[chatbot, msg, loading_indicator, state], + outputs=[chatbot, msg, loading_indicator, state, rating_display, rating_input], ) clear_chat.click( clear_chat_history, - outputs=[chatbot, msg] + outputs=[chatbot, msg, rating_display, rating_input] ) return ( @@ -608,7 +819,8 @@ def create_rag_qa_notes_management_tab(): with gr.Row(): with gr.Column(scale=1): # Search Notes - search_notes_input = gr.Textbox(label="Search Notes by Keywords") + search_notes_title = gr.Textbox(label="Search Notes by Title") + search_notes_by_keyword = gr.Textbox(label="Search Notes by Keywords") search_notes_button = gr.Button("Search Notes") notes_list = gr.Dropdown(label="Notes", choices=[]) @@ -617,24 +829,34 @@ def create_rag_qa_notes_management_tab(): delete_note_button = gr.Button("Delete Note") note_title_input = gr.Textbox(label="Note Title") note_content_input = gr.TextArea(label="Note Content", lines=20) - note_keywords_input = gr.Textbox(label="Note Keywords (comma-separated)") + note_keywords_input = gr.Textbox(label="Note Keywords (comma-separated)", value="default_note_keyword") save_note_button = gr.Button("Save Note") create_new_note_button = gr.Button("Create New Note") status_message = gr.HTML() # Function Definitions - def search_notes(keywords): + def search_notes(search_notes_title, keywords): if keywords: keywords_list = [kw.strip() for kw in keywords.split(',')] notes_data, total_pages, total_count = get_notes_by_keywords(keywords_list) - choices = [f"{title} (ID: {note_id})" for note_id, title, content, timestamp in notes_data] - return gr.update(choices=choices) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") + elif search_notes_title: + notes_data, total_pages, total_count = search_notes_titles(search_notes_title) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") else: - return gr.update(choices=[]) + # This will now return all notes, ordered by timestamp + notes_data, total_pages, total_count = search_notes_titles("") + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"All notes ({total_count} total)") search_notes_button.click( search_notes, - inputs=[search_notes_input], + inputs=[search_notes_title, search_notes_by_keyword], outputs=[notes_list] ) @@ -698,7 +920,7 @@ def delete_selected_note(state_value): # Reset state state_value["selected_note_id"] = None # Update notes list - updated_notes = search_notes("") + updated_notes = search_notes("", "") return updated_notes, gr.update(value="Note deleted successfully."), state_value else: return gr.update(), gr.update(value="No note selected."), state_value @@ -736,7 +958,20 @@ def 
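rag_qa_chat_wrapper above is a generator callback: the first yield turns the loading indicator on, the final yield delivers the answer plus the rating widgets and hides it again, and every except branch yields the same tuple shape so the outputs list always lines up. A stripped-down sketch of that pattern with invented component names and no RAG involved:

import time
import gradio as gr

def answer_with_spinner(message, history):
    history = history or []
    # First yield: show the "Loading..." HTML while work is in progress.
    yield history, "", gr.update(visible=True)
    time.sleep(1)                                 # stand-in for retrieval + generation
    history = history + [(message, f"Echo: {message}")]
    # Final yield: updated chat, cleared textbox, spinner hidden again.
    yield history, "", gr.update(visible=False)

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    loading = gr.HTML("<p>Loading...</p>", visible=False)
    msg.submit(answer_with_spinner, inputs=[msg, chatbot], outputs=[chatbot, msg, loading])

# demo.launch()  # uncomment to try it locally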
create_rag_qa_chat_management_tab(): with gr.Row(): with gr.Column(scale=1): # Search Conversations - search_conversations_input = gr.Textbox(label="Search Conversations by Keywords") + with gr.Group(): + gr.Markdown("## Search Conversations") + title_search = gr.Textbox( + label="Search by Title", + placeholder="Enter title to search..." + ) + content_search = gr.Textbox( + label="Search in Chat Content", + placeholder="Enter text to search in messages..." + ) + keyword_search = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="keyword1, keyword2, ..." + ) search_conversations_button = gr.Button("Search Conversations") conversations_list = gr.Dropdown(label="Conversations", choices=[]) new_conversation_button = gr.Button("New Conversation") @@ -750,26 +985,40 @@ def create_rag_qa_chat_management_tab(): status_message = gr.HTML() # Function Definitions - def search_conversations(keywords): - if keywords: - keywords_list = [kw.strip() for kw in keywords.split(',')] - conversations, total_pages, total_count = search_conversations_by_keywords(keywords_list) - else: - conversations, total_pages, total_count = get_all_conversations() + def search_conversations(title_query, content_query, keywords): + try: + # Parse keywords if provided + keywords_list = None + if keywords and keywords.strip(): + keywords_list = [kw.strip() for kw in keywords.split(',')] + + # Search using existing search_conversations_by_keywords function with all criteria + results, total_pages, total_count = search_conversations_by_keywords( + keywords=keywords_list, + title_query=title_query if title_query.strip() else None, + content_query=content_query if content_query.strip() else None + ) - # Build choices as list of titles (ensure uniqueness) - choices = [] - mapping = {} - for conversation_id, title in conversations: - display_title = f"{title} (ID: {conversation_id[:8]})" - choices.append(display_title) - mapping[display_title] = conversation_id + # Build choices as list of titles (ensure uniqueness) + choices = [] + mapping = {} + for conv in results: + conversation_id = conv['conversation_id'] + title = conv['title'] + display_title = f"{title} (ID: {conversation_id[:8]})" + choices.append(display_title) + mapping[display_title] = conversation_id + + return gr.update(choices=choices), mapping - return gr.update(choices=choices), mapping + except Exception as e: + logging.error(f"Error in search_conversations: {str(e)}") + return gr.update(choices=[]), {} + # Update the search button click event search_conversations_button.click( search_conversations, - inputs=[search_conversations_input], + inputs=[title_search, content_search, keyword_search], outputs=[conversations_list, conversation_mapping] ) @@ -926,19 +1175,18 @@ def create_new_conversation(state_value, mapping): ] ) - def delete_messages_in_conversation(conversation_id): - """Helper function to delete all messages in a conversation.""" + def delete_messages_in_conversation_wrapper(conversation_id): + """Wrapper function to delete all messages in a conversation.""" try: - execute_query("DELETE FROM rag_qa_chats WHERE conversation_id = ?", (conversation_id,)) + delete_messages_in_conversation(conversation_id) logging.info(f"Messages in conversation '{conversation_id}' deleted successfully.") except Exception as e: logging.error(f"Error deleting messages in conversation '{conversation_id}': {e}") raise - def get_conversation_title(conversation_id): + def get_conversation_title_wrapper(conversation_id): """Helper function to get the 
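The management tab now feeds three independent criteria to search_conversations_by_keywords, passing None for anything the user left blank so the DB layer can skip that clause. The normalisation step on its own, as a sketch (the real call site is the one shown above):

def build_search_criteria(title_query, content_query, keywords):
    # Blank strings become None; keywords become a stripped list or None.
    keywords_list = None
    if keywords and keywords.strip():
        keywords_list = [kw.strip() for kw in keywords.split(',') if kw.strip()]
    return {
        "keywords": keywords_list,
        "title_query": title_query if title_query and title_query.strip() else None,
        "content_query": content_query if content_query and content_query.strip() else None,
    }

print(build_search_criteria("  ", "vector db", "rag, notes"))
# {'keywords': ['rag', 'notes'], 'title_query': None, 'content_query': 'vector db'}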
conversation title.""" - query = "SELECT title FROM conversation_metadata WHERE conversation_id = ?" - result = execute_query(query, (conversation_id,)) + result = get_conversation_title(conversation_id) if result: return result[0][0] else: @@ -1068,19 +1316,6 @@ def export_data_function(export_option, selected_conversations, selected_notes): ) - - -def update_conversation_title(conversation_id, new_title): - """Update the title of a conversation.""" - try: - query = "UPDATE conversation_metadata SET title = ? WHERE conversation_id = ?" - execute_query(query, (new_title, conversation_id)) - logging.info(f"Conversation '{conversation_id}' title updated to '{new_title}'") - except Exception as e: - logging.error(f"Error updating conversation title: {e}") - raise - - def convert_file_to_text(file_path): """Convert various file types to plain text.""" file_extension = os.path.splitext(file_path)[1].lower() diff --git a/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py b/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py index f0181feda..ca9f33170 100644 --- a/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py +++ b/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py @@ -10,16 +10,13 @@ # # Local Imports from App_Function_Libraries.Chunk_Lib import improved_chunking_process -from App_Function_Libraries.DB.DB_Manager import update_media_content, load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import update_media_content, list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import fetch_item_details, fetch_items_by_keyword, \ fetch_items_by_content, fetch_items_by_title_or_url from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_chunk from App_Function_Libraries.Utils.Utils import load_comprehensive_config, default_api_endpoint, global_api_endpoints, \ format_api_name - - -# # ###################################################################################################################### # @@ -36,6 +33,10 @@ def create_resummary_tab(): except Exception as e: logging.error(f"Error setting default API endpoint: {str(e)}") default_value = None + + # Get initial prompts for first page + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + with gr.TabItem("Re-Summarize", visible=True): gr.Markdown("# Re-Summarize Existing Content") with gr.Row(): @@ -48,7 +49,6 @@ def create_resummary_tab(): item_mapping = gr.State({}) with gr.Row(): - # Refactored API selection dropdown api_name_input = gr.Dropdown( choices=["None"] + [format_api_name(api) for api in global_api_endpoints], value=default_value, @@ -70,9 +70,17 @@ def create_resummary_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + + # Add pagination controls for preset prompts + with gr.Row(visible=False) as preset_prompt_controls: + prev_page = gr.Button("Previous") + current_page_text = gr.Text(f"Page {current_page} of {total_pages}") + next_page = gr.Button("Next") + current_page_state = gr.State(value=1) + with gr.Row(): preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=initial_prompts, visible=False) with gr.Row(): custom_prompt_input = gr.Textbox(label="Custom Prompt", @@ -101,6 +109,15 @@ def create_resummary_tab(): lines=3, visible=False) + def update_prompt_page(direction, current_page_val): + new_page = max(1, min(total_pages, current_page_val + direction)) + prompts, 
_, _ = list_prompts(page=new_page, per_page=10) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -108,6 +125,19 @@ def update_prompts(preset_name): gr.update(value=prompts["system_prompt"], visible=True) ) + # Connect pagination buttons + prev_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_page_state], + outputs=[preset_prompt, current_page_text, current_page_state] + ) + + next_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_page_state], + outputs=[preset_prompt, current_page_text, current_page_state] + ) + preset_prompt.change( update_prompts, inputs=preset_prompt, @@ -124,9 +154,9 @@ def update_prompts(preset_name): outputs=[custom_prompt_input, system_prompt_input] ) preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, preset_prompt_controls] ) # Connect the UI elements @@ -155,7 +185,12 @@ def update_prompts(preset_name): outputs=result_output ) - return search_query_input, search_type_input, search_button, items_output, item_mapping, api_name_input, api_key_input, chunking_options_checkbox, chunking_options_box, chunk_method, max_chunk_size, chunk_overlap, custom_prompt_checkbox, custom_prompt_input, resummarize_button, result_output + return ( + search_query_input, search_type_input, search_button, items_output, + item_mapping, api_name_input, api_key_input, chunking_options_checkbox, + chunking_options_box, chunk_method, max_chunk_size, chunk_overlap, + custom_prompt_checkbox, custom_prompt_input, resummarize_button, result_output + ) def update_resummarize_dropdown(search_query, search_type): diff --git a/App_Function_Libraries/Gradio_UI/Search_Tab.py b/App_Function_Libraries/Gradio_UI/Search_Tab.py index 2ad075b81..64c9527b5 100644 --- a/App_Function_Libraries/Gradio_UI/Search_Tab.py +++ b/App_Function_Libraries/Gradio_UI/Search_Tab.py @@ -11,8 +11,8 @@ # # Local Imports from App_Function_Libraries.DB.DB_Manager import view_database, search_and_display_items, get_all_document_versions, \ - fetch_item_details_single, fetch_paginated_data, fetch_item_details, get_latest_transcription -from App_Function_Libraries.DB.SQLite_DB import search_prompts, get_document_version + fetch_item_details_single, fetch_paginated_data, fetch_item_details, get_latest_transcription, search_prompts, \ + get_document_version from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown, update_detailed_view from App_Function_Libraries.Utils.Utils import get_database_path, format_text_with_line_breaks # @@ -80,8 +80,8 @@ def format_as_html(content, title): """ def create_search_tab(): - with gr.TabItem("Search / Detailed View", visible=True): - gr.Markdown("# Search across all ingested items in the Database") + with gr.TabItem("Media DB Search / Detailed View", visible=True): + gr.Markdown("# Search across all ingested items in the Media Database") with gr.Row(): with gr.Column(scale=1): gr.Markdown("by Title / URL / Keyword / or Content via SQLite Full-Text-Search") @@ -150,8 +150,8 @@ def display_search_results(query): def create_search_summaries_tab(): - with gr.TabItem("Search/View Title+Summary", visible=True): - gr.Markdown("# Search across all ingested items in the Database and review their summaries") + with gr.TabItem("Media DB 
Search/View Title+Summary", visible=True): + gr.Markdown("# Search across all ingested items in the Media Database and review their summaries") gr.Markdown("Search by Title / URL / Keyword / or Content via SQLite Full-Text-Search") with gr.Row(): with gr.Column(): diff --git a/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py b/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py index 20c5bc90d..3d03af228 100644 --- a/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py +++ b/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py @@ -10,10 +10,12 @@ # External Imports import gradio as gr import yt_dlp + +from App_Function_Libraries.Chunk_Lib import improved_chunking_process # # Local Imports -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts, add_media_to_database, \ - check_media_and_whisper_model, check_existing_media, update_media_content_with_version +from App_Function_Libraries.DB.DB_Manager import add_media_to_database, \ + check_media_and_whisper_model, check_existing_media, update_media_content_with_version, list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models, update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import error_handler from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_transcription, perform_summarization, \ @@ -65,15 +67,20 @@ def create_video_transcription_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): + # Add pagination controls preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=[], visible=False) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) with gr.Row(): system_prompt_input = gr.Textbox(label="System Prompt", value="""Title | Author |
---|---|
{html.escape(title)} | {html.escape(author)} |
Error fetching prompts: {e}
", "Error", 0, [] - - def update_page(page, entries_per_page): - results, pagination, total_pages, prompt_choices = view_database(page, entries_per_page) - next_disabled = page >= total_pages - prev_disabled = page <= 1 - return results, pagination, page, gr.update(interactive=not next_disabled), gr.update( - interactive=not prev_disabled), gr.update(choices=prompt_choices) - - def go_to_next_page(current_page, entries_per_page): - next_page = current_page + 1 - return update_page(next_page, entries_per_page) - - def go_to_previous_page(current_page, entries_per_page): - previous_page = max(1, current_page - 1) - return update_page(previous_page, entries_per_page) - - def display_selected_prompt(prompt_name): - details = fetch_prompt_details(prompt_name) - if details: - title, author, description, system_prompt, user_prompt, keywords = details - # Handle None values by converting them to empty strings - description = description or "" - system_prompt = system_prompt or "" - user_prompt = user_prompt or "" - author = author or "Unknown" - keywords = keywords or "" - - html_content = f""" -Description: {html.escape(description)}
-{html.escape(system_prompt)}-
{html.escape(user_prompt)}-
Keywords: {html.escape(keywords)}
-Prompt not found.
" - - view_button.click( - fn=update_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - next_page_button.click( - fn=go_to_next_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - previous_page_button.click( - fn=go_to_previous_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - prompt_selector.change( - fn=display_selected_prompt, - inputs=[prompt_selector], - outputs=[selected_prompt_display] - ) - def format_as_html(content, title): escaped_content = html.escape(content) formatted_content = escaped_content.replace('\n', 'Please select at least one keyword.
", "Please select at least one keyword.
", @@ -802,14 +713,17 @@ def view_items(keywords, page, entries_per_page): ) try: + # Ensure keywords is a list + keywords_list = keywords if isinstance(keywords, list) else [keywords] + # Get conversations for selected keywords conversations, conv_total_pages, conv_count = search_conversations_by_keywords( - keywords, page, entries_per_page + keywords_list, page, entries_per_page ) # Get notes for selected keywords notes, notes_total_pages, notes_count = get_notes_by_keywords( - keywords, page, entries_per_page + keywords_list, page, entries_per_page ) # Format results as HTML @@ -833,6 +747,7 @@ def view_items(keywords, page, entries_per_page): gr.update(interactive=not prev_disabled) ) except Exception as e: + logging.error(f"Error in view_items: {str(e)}") return ( f"Error: {str(e)}
", f"Error: {str(e)}
", diff --git a/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py b/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py index 80b19ba9f..5204547e0 100644 --- a/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py +++ b/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py @@ -22,7 +22,7 @@ # Local Imports from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_from_sitemap, scrape_by_url_level, \ scrape_article, collect_bookmarks, scrape_and_summarize_multiple, collect_urls_from_file -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize @@ -314,12 +314,22 @@ def create_website_scraping_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) with gr.Row(): temp_slider = gr.Slider(0.1, 2.0, 0.7, label="Temperature") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) with gr.Row(): + # Add pagination controls preset_prompt = gr.Dropdown( label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=[], visible=False ) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + with gr.Row(): website_custom_prompt_input = gr.Textbox( label="Custom Prompt", @@ -421,10 +431,57 @@ def update_ui_for_scrape_method(method): inputs=[custom_prompt_checkbox], outputs=[website_custom_prompt_input, system_prompt_input] ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + 
page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) def update_prompts(preset_name): diff --git a/App_Function_Libraries/Gradio_UI/Chat_Workflows.py b/App_Function_Libraries/Gradio_UI/Workflows_tab.py similarity index 96% rename from App_Function_Libraries/Gradio_UI/Chat_Workflows.py rename to App_Function_Libraries/Gradio_UI/Workflows_tab.py index 802df6de9..5c911d290 100644 --- a/App_Function_Libraries/Gradio_UI/Chat_Workflows.py +++ b/App_Function_Libraries/Gradio_UI/Workflows_tab.py @@ -1,5 +1,5 @@ # Chat_Workflows.py -# Description: UI for Chat Workflows +# Description: Gradio UI for Chat Workflows # # Imports import json @@ -9,11 +9,11 @@ # External Imports import gradio as gr # +# Local Imports from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper, search_conversations, \ load_conversation -from App_Function_Libraries.Chat import save_chat_history_to_db_wrapper +from App_Function_Libraries.Chat.Chat_Functions import save_chat_history_to_db_wrapper from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name - # ############################################################################################################ # @@ -74,6 +74,7 @@ def chat_workflows_tab(): clear_btn = gr.Button("Clear Chat") chat_media_name = gr.Textbox(label="Custom Chat Name(optional)") save_btn = gr.Button("Save Chat to Database") + save_status = gr.Textbox(label="Save Status", interactive=False) def update_workflow_ui(workflow_name): if not workflow_name: @@ -164,7 +165,7 @@ def process_workflow_step(message, history, context, workflow_name, api_endpoint save_btn.click( save_chat_history_to_db_wrapper, inputs=[chatbot, conversation_id, media_content, chat_media_name], - outputs=[conversation_id, gr.Textbox(label="Save Status")] + outputs=[conversation_id, save_status] ) search_conversations_btn.click( diff --git a/App_Function_Libraries/Gradio_UI/Writing_tab.py b/App_Function_Libraries/Gradio_UI/Writing_tab.py index 37f5a8fdd..306513314 100644 --- a/App_Function_Libraries/Gradio_UI/Writing_tab.py +++ b/App_Function_Libraries/Gradio_UI/Writing_tab.py @@ -46,17 +46,17 @@ def grammar_style_check(input_text, custom_prompt, api_name, api_key, system_pro def create_grammar_style_check_tab(): - try: - default_value = None - if default_api_endpoint: - if default_api_endpoint in global_api_endpoints: - default_value = format_api_name(default_api_endpoint) - else: - logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") - except Exception as e: - logging.error(f"Error setting default API endpoint: {str(e)}") - default_value = None with gr.TabItem("Grammar and Style Check", visible=True): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.Row(): with gr.Column(): gr.Markdown("# Grammar and Style Check") @@ -317,63 +317,63 @@ def create_document_feedback_tab(): with gr.Row(): compare_button = gr.Button("Compare Feedback") - 
feedback_history = gr.State([]) - - def add_custom_persona(name, description): - updated_choices = persona_dropdown.choices + [name] - persona_prompts[name] = f"As {name}, {description}, provide feedback on the following text:" - return gr.update(choices=updated_choices) - - def update_feedback_history(current_text, persona, feedback): - # Ensure feedback_history.value is initialized and is a list - if feedback_history.value is None: - feedback_history.value = [] - - history = feedback_history.value - - # Append the new entry to the history - history.append({"text": current_text, "persona": persona, "feedback": feedback}) - - # Keep only the last 5 entries in the history - feedback_history.value = history[-10:] - - # Generate and return the updated HTML - return generate_feedback_history_html(feedback_history.value) - - def compare_feedback(text, selected_personas, api_name, api_key): - results = [] - for persona in selected_personas: - feedback = generate_writing_feedback(text, persona, "Overall", api_name, api_key) - results.append(f"### {persona}'s Feedback:\n{feedback}\n\n") - return "\n".join(results) - - add_custom_persona_button.click( - fn=add_custom_persona, - inputs=[custom_persona_name, custom_persona_description], - outputs=persona_dropdown - ) - - get_feedback_button.click( - fn=lambda text, persona, aspect, api_name, api_key: ( - generate_writing_feedback(text, persona, aspect, api_name, api_key), - calculate_readability(text), - update_feedback_history(text, persona, generate_writing_feedback(text, persona, aspect, api_name, api_key)) - ), - inputs=[input_text, persona_dropdown, aspect_dropdown, api_name_input, api_key_input], - outputs=[feedback_output, readability_output, feedback_history_display] - ) - - compare_button.click( - fn=compare_feedback, - inputs=[input_text, compare_personas, api_name_input, api_key_input], - outputs=feedback_output - ) - - generate_prompt_button.click( - fn=generate_writing_prompt, - inputs=[persona_dropdown, api_name_input, api_key_input], - outputs=input_text - ) + feedback_history = gr.State([]) + + def add_custom_persona(name, description): + updated_choices = persona_dropdown.choices + [name] + persona_prompts[name] = f"As {name}, {description}, provide feedback on the following text:" + return gr.update(choices=updated_choices) + + def update_feedback_history(current_text, persona, feedback): + # Ensure feedback_history.value is initialized and is a list + if feedback_history.value is None: + feedback_history.value = [] + + history = feedback_history.value + + # Append the new entry to the history + history.append({"text": current_text, "persona": persona, "feedback": feedback}) + + # Keep only the last 5 entries in the history + feedback_history.value = history[-10:] + + # Generate and return the updated HTML + return generate_feedback_history_html(feedback_history.value) + + def compare_feedback(text, selected_personas, api_name, api_key): + results = [] + for persona in selected_personas: + feedback = generate_writing_feedback(text, persona, "Overall", api_name, api_key) + results.append(f"### {persona}'s Feedback:\n{feedback}\n\n") + return "\n".join(results) + + add_custom_persona_button.click( + fn=add_custom_persona, + inputs=[custom_persona_name, custom_persona_description], + outputs=persona_dropdown + ) + + get_feedback_button.click( + fn=lambda text, persona, aspect, api_name, api_key: ( + generate_writing_feedback(text, persona, aspect, api_name, api_key), + calculate_readability(text), + update_feedback_history(text, 
persona, generate_writing_feedback(text, persona, aspect, api_name, api_key)) + ), + inputs=[input_text, persona_dropdown, aspect_dropdown, api_name_input, api_key_input], + outputs=[feedback_output, readability_output, feedback_history_display] + ) + + compare_button.click( + fn=compare_feedback, + inputs=[input_text, compare_personas, api_name_input, api_key_input], + outputs=feedback_output + ) + + generate_prompt_button.click( + fn=generate_writing_prompt, + inputs=[persona_dropdown, api_name_input, api_key_input], + outputs=input_text + ) return input_text, feedback_output, readability_output, feedback_history_display diff --git a/App_Function_Libraries/Plaintext/Plaintext_Files.py b/App_Function_Libraries/Plaintext/Plaintext_Files.py index f5038a967..ba3c686db 100644 --- a/App_Function_Libraries/Plaintext/Plaintext_Files.py +++ b/App_Function_Libraries/Plaintext/Plaintext_Files.py @@ -2,9 +2,12 @@ # Description: This file contains functions for reading and writing plaintext files. # # Import necessary libraries +import logging import os import tempfile import zipfile +from datetime import datetime + # # External Imports from docx2txt import docx2txt @@ -12,57 +15,161 @@ # # Local Imports from App_Function_Libraries.Gradio_UI.Import_Functionality import import_data +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram + + # ####################################################################################################################### # # Function Definitions -def import_plain_text_file(file_path, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, - api_key): +def import_plain_text_file(file_path, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Import a single plain text file.""" try: + log_counter("file_processing_attempt", labels={"file_path": file_path}) + + # Extract title from filename + title = os.path.splitext(os.path.basename(file_path))[0] + # Determine the file type and convert if necessary file_extension = os.path.splitext(file_path)[1].lower() - if file_extension == '.rtf': - with tempfile.NamedTemporaryFile(suffix='.md', delete=False) as temp_file: - convert_file(file_path, 'md', outputfile=temp_file.name) - file_path = temp_file.name - elif file_extension == '.docx': - content = docx2txt.process(file_path) - else: - with open(file_path, 'r', encoding='utf-8') as file: - content = file.read() - - # Process the content - return import_data(content, title, author, keywords, system_prompt, - user_prompt, auto_summarize, api_name, api_key) + + # Get the content based on file type + try: + if file_extension == '.rtf': + with tempfile.NamedTemporaryFile(suffix='.md', delete=False) as temp_file: + convert_file(file_path, 'md', outputfile=temp_file.name) + file_path = temp_file.name + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + log_counter("rtf_conversion_success", labels={"file_path": file_path}) + elif file_extension == '.docx': + content = docx2txt.process(file_path) + log_counter("docx_conversion_success", labels={"file_path": file_path}) + else: + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + except Exception as e: + logging.error(f"Error reading file content: {str(e)}") + return f"Error reading file content: {str(e)}" + + # Import the content + result = import_data( + content, # Pass the content directly + title, + author, + keywords, + user_prompt, # This is the custom_prompt parameter + None, # No 
summary - let auto_summarize handle it + auto_summarize, + api_name, + api_key + ) + + log_counter("file_processing_success", labels={"file_path": file_path}) + return result + except Exception as e: - return f"Error processing file: {str(e)}" + logging.exception(f"Error processing file {file_path}") + log_counter("file_processing_error", labels={"file_path": file_path, "error": str(e)}) + return f"Error processing file {os.path.basename(file_path)}: {str(e)}" + -def process_plain_text_zip_file(zip_file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): +def process_plain_text_zip_file(zip_file, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Process multiple text files from a zip archive.""" results = [] - with tempfile.TemporaryDirectory() as temp_dir: - with zipfile.ZipFile(zip_file.name, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - for filename in os.listdir(temp_dir): - if filename.lower().endswith(('.md', '.txt', '.rtf', '.docx')): - file_path = os.path.join(temp_dir, filename) - result = import_plain_text_file(file_path, title, author, keywords, system_prompt, - user_prompt, auto_summarize, api_name, api_key) - results.append(f"File: {filename} - {result}") - - return "\n".join(results) - - -def import_file_handler(file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): - if file.name.lower().endswith(('.md', '.txt', '.rtf', '.docx')): - return import_plain_text_file(file.name, title, author, keywords, system_prompt, user_prompt, auto_summarize, - api_name, api_key) - elif file.name.lower().endswith('.zip'): - return process_plain_text_zip_file(file, title, author, keywords, system_prompt, user_prompt, auto_summarize, - api_name, api_key) - else: - return "Unsupported file type. Please upload a .md, .txt, .rtf, .docx file or a .zip file containing these file types." + try: + with tempfile.TemporaryDirectory() as temp_dir: + with zipfile.ZipFile(zip_file.name, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + for filename in os.listdir(temp_dir): + if filename.lower().endswith(('.md', '.txt', '.rtf', '.docx')): + file_path = os.path.join(temp_dir, filename) + result = import_plain_text_file( + file_path=file_path, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + results.append(f"📄 {filename}: {result}") + + return "\n\n".join(results) + except Exception as e: + logging.exception(f"Error processing zip file: {str(e)}") + return f"Error processing zip file: {str(e)}" + + + +def import_file_handler(files, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Handle the import of one or more files, including zip files.""" + try: + if not files: + log_counter("plaintext_import_error", labels={"error": "No files uploaded"}) + return "No files uploaded." 
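An aside on the handler being added here: the very next step coerces the Gradio upload value into a list so one loop can cover both a single upload and a multi-file upload. Below is a minimal, illustrative sketch of that normalization, assuming the usual gr.File behaviour (one file object when file_count="single", a list when file_count="multiple"); normalize_uploads is a hypothetical helper named only for this example and is not part of the patch.

# Illustrative sketch only: mirrors the isinstance(files, list) check that follows in the patch.
def normalize_uploads(files):
    # Gradio hands back None when nothing was uploaded, a single object for
    # file_count="single", and a list for file_count="multiple".
    if files is None:
        return []
    return files if isinstance(files, list) else [files]

assert normalize_uploads(None) == []
assert normalize_uploads("notes.txt") == ["notes.txt"]
assert normalize_uploads(["a.md", "b.zip"]) == ["a.md", "b.zip"]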
+ + # Convert single file to list for consistent processing + if not isinstance(files, list): + files = [files] + + results = [] + for file in files: + log_counter("plaintext_import_attempt", labels={"file_name": file.name}) + + start_time = datetime.now() + + if not os.path.exists(file.name): + log_counter("plaintext_import_error", labels={"error": "File not found", "file_name": file.name}) + results.append(f"❌ File not found: {file.name}") + continue + + if file.name.lower().endswith(('.md', '.txt', '.rtf', '.docx')): + result = import_plain_text_file( + file_path=file.name, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + log_counter("plaintext_import_success", labels={"file_name": file.name}) + results.append(f"📄 {file.name}: {result}") + + elif file.name.lower().endswith('.zip'): + result = process_plain_text_zip_file( + zip_file=file, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + log_counter("zip_import_success", labels={"file_name": file.name}) + results.append(f"📦 {file.name}:\n{result}") + + else: + log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]}) + results.append(f"❌ Unsupported file type: {file.name}") + continue + + end_time = datetime.now() + processing_time = (end_time - start_time).total_seconds() + log_histogram("plaintext_import_duration", processing_time, labels={"file_name": file.name}) + + return "\n\n".join(results) + + except Exception as e: + logging.exception("Error in import_file_handler") + log_counter("plaintext_import_error", labels={"error": str(e)}) + return f"❌ Error during import: {str(e)}" # # End of Plaintext_Files.py diff --git a/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py b/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py index d192f55b8..c037eeadb 100644 --- a/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py +++ b/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py @@ -4,7 +4,7 @@ # Imports import re -from App_Function_Libraries.Chat import chat_api_call +from App_Function_Libraries.Chat.Chat_Functions import chat_api_call # # Local Imports # diff --git a/App_Function_Libraries/Prompt_Engineering/__Init__.py b/App_Function_Libraries/Prompt_Engineering/__Init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/App_Function_Libraries/RAG/ChromaDB_Library.py b/App_Function_Libraries/RAG/ChromaDB_Library.py index 3cec5fd8d..623e5adb4 100644 --- a/App_Function_Libraries/RAG/ChromaDB_Library.py +++ b/App_Function_Libraries/RAG/ChromaDB_Library.py @@ -49,36 +49,37 @@ # Function to preprocess and store all existing content in the database -def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"): - unprocessed_media = get_unprocessed_media(db=database) - total_media = len(unprocessed_media) - - for index, row in enumerate(unprocessed_media, 1): - media_id, content, media_type, file_name = row - collection_name = f"{media_type}_{media_id}" - - logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}") - - try: - process_and_store_content( - database=database, - content=content, - collection_name=collection_name, - media_id=media_id, - file_name=file_name or f"{media_type}_{media_id}", - create_embeddings=True, - 
create_contextualized=create_contextualized, - api_name=api_name - ) - - # Mark the media as processed in the database - mark_media_as_processed(database, media_id) - - logger.info(f"Successfully processed media ID {media_id}") - except Exception as e: - logger.error(f"Error processing media ID {media_id}: {str(e)}") - - logger.info("Finished preprocessing all unprocessed content") +# FIXME - Deprecated +# def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"): +# unprocessed_media = get_unprocessed_media(db=database) +# total_media = len(unprocessed_media) +# +# for index, row in enumerate(unprocessed_media, 1): +# media_id, content, media_type, file_name = row +# collection_name = f"{media_type}_{media_id}" +# +# logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}") +# +# try: +# process_and_store_content( +# database=database, +# content=content, +# collection_name=collection_name, +# media_id=media_id, +# file_name=file_name or f"{media_type}_{media_id}", +# create_embeddings=True, +# create_contextualized=create_contextualized, +# api_name=api_name +# ) +# +# # Mark the media as processed in the database +# mark_media_as_processed(database, media_id) +# +# logger.info(f"Successfully processed media ID {media_id}") +# except Exception as e: +# logger.error(f"Error processing media ID {media_id}: {str(e)}") +# +# logger.info("Finished preprocessing all unprocessed content") def batched(iterable, n): @@ -233,7 +234,10 @@ def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids logging.info(f"Number of embeddings: {len(embeddings)}, Dimension: {embedding_dim}") try: - # Attempt to get or create the collection + # Clean metadata + cleaned_metadatas = [clean_metadata(metadata) for metadata in metadatas] + + # Try to get or create the collection try: collection = chroma_client.get_collection(name=collection_name) logging.info(f"Existing collection '{collection_name}' found") @@ -258,7 +262,7 @@ def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids documents=texts, embeddings=embeddings, ids=ids, - metadatas=metadatas + metadatas=cleaned_metadatas ) logging.info(f"Successfully upserted {len(embeddings)} embeddings") @@ -290,12 +294,19 @@ def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[st # Fetch a sample of embeddings to check metadata sample_results = collection.get(limit=10, include=["metadatas"]) - if not sample_results['metadatas']: - raise ValueError("No metadata found in the collection") + if not sample_results.get('metadatas') or not any(sample_results['metadatas']): + logging.warning(f"No metadata found in the collection '{collection_name}'. 
Skipping this collection.") + return [] # Check if all embeddings use the same model and provider - embedding_models = [metadata.get('embedding_model') for metadata in sample_results['metadatas'] if metadata.get('embedding_model')] - embedding_providers = [metadata.get('embedding_provider') for metadata in sample_results['metadatas'] if metadata.get('embedding_provider')] + embedding_models = [ + metadata.get('embedding_model') for metadata in sample_results['metadatas'] + if metadata and metadata.get('embedding_model') + ] + embedding_providers = [ + metadata.get('embedding_provider') for metadata in sample_results['metadatas'] + if metadata and metadata.get('embedding_provider') + ] if not embedding_models or not embedding_providers: raise ValueError("Embedding model or provider information not found in metadata") @@ -319,13 +330,13 @@ def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[st ) if not results['documents'][0]: - logging.warning("No results found for the query") + logging.warning(f"No results found for the query in collection '{collection_name}'.") return [] return [{"content": doc, "metadata": meta} for doc, meta in zip(results['documents'][0], results['metadatas'][0])] except Exception as e: - logging.error(f"Error in vector_search: {str(e)}", exc_info=True) - raise + logging.error(f"Error in vector_search for collection '{collection_name}': {str(e)}", exc_info=True) + return [] def schedule_embedding(media_id: int, content: str, media_name: str): @@ -350,6 +361,21 @@ def schedule_embedding(media_id: int, content: str, media_name: str): logging.error(f"Error scheduling embedding for media_id {media_id}: {str(e)}") +def clean_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]: + """Clean metadata by removing None values and converting to appropriate types""" + cleaned = {} + for key, value in metadata.items(): + if value is not None: # Skip None values + if isinstance(value, (str, int, float, bool)): + cleaned[key] = value + elif isinstance(value, (np.int32, np.int64)): + cleaned[key] = int(value) + elif isinstance(value, (np.float32, np.float64)): + cleaned[key] = float(value) + else: + cleaned[key] = str(value) # Convert other types to string + return cleaned + # Function to process content, create chunks, embeddings, and store in ChromaDB and SQLite # def process_and_store_content(content: str, collection_name: str, media_id: int): # # Process the content into chunks diff --git a/App_Function_Libraries/RAG/Embeddings_Create.py b/App_Function_Libraries/RAG/Embeddings_Create.py index 0f57abbed..04ca6b09a 100644 --- a/App_Function_Libraries/RAG/Embeddings_Create.py +++ b/App_Function_Libraries/RAG/Embeddings_Create.py @@ -25,8 +25,6 @@ # # Functions: -# FIXME - Version 2 - # Load configuration loaded_config = load_comprehensive_config() embedding_provider = loaded_config['Embeddings']['embedding_provider'] @@ -331,177 +329,6 @@ def create_openai_embedding(text: str, model: str) -> List[float]: return embedding -# FIXME - Version 1 -# # FIXME - Add all globals to summarize.py -# loaded_config = load_comprehensive_config() -# embedding_provider = loaded_config['Embeddings']['embedding_provider'] -# embedding_model = loaded_config['Embeddings']['embedding_model'] -# embedding_api_url = loaded_config['Embeddings']['embedding_api_url'] -# embedding_api_key = loaded_config['Embeddings']['embedding_api_key'] -# -# # Embedding Chunking Settings -# chunk_size = loaded_config['Embeddings']['chunk_size'] -# overlap = loaded_config['Embeddings']['overlap'] -# 
-# -# # FIXME - Add logging -# -# class HuggingFaceEmbedder: -# def __init__(self, model_name, timeout_seconds=120): # Default timeout of 2 minutes -# self.model_name = model_name -# self.tokenizer = None -# self.model = None -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# self.timeout_seconds = timeout_seconds -# self.last_used_time = 0 -# self.unload_timer = None -# -# def load_model(self): -# if self.model is None: -# self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) -# self.model = AutoModel.from_pretrained(self.model_name) -# self.model.to(self.device) -# self.last_used_time = time.time() -# self.reset_timer() -# -# def unload_model(self): -# if self.model is not None: -# del self.model -# del self.tokenizer -# if torch.cuda.is_available(): -# torch.cuda.empty_cache() -# self.model = None -# self.tokenizer = None -# if self.unload_timer: -# self.unload_timer.cancel() -# -# def reset_timer(self): -# if self.unload_timer: -# self.unload_timer.cancel() -# self.unload_timer = Timer(self.timeout_seconds, self.unload_model) -# self.unload_timer.start() -# -# def create_embeddings(self, texts): -# self.load_model() -# inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512) -# inputs = {k: v.to(self.device) for k, v in inputs.items()} -# with torch.no_grad(): -# outputs = self.model(**inputs) -# embeddings = outputs.last_hidden_state.mean(dim=1) -# return embeddings.cpu().numpy() -# -# # Global variable to hold the embedder -# huggingface_embedder = None -# -# -# class RateLimiter: -# def __init__(self, max_calls, period): -# self.max_calls = max_calls -# self.period = period -# self.calls = [] -# self.lock = Lock() -# -# def __call__(self, func): -# def wrapper(*args, **kwargs): -# with self.lock: -# now = time.time() -# self.calls = [call for call in self.calls if call > now - self.period] -# if len(self.calls) >= self.max_calls: -# sleep_time = self.calls[0] - (now - self.period) -# time.sleep(sleep_time) -# self.calls.append(time.time()) -# return func(*args, **kwargs) -# return wrapper -# -# -# def exponential_backoff(max_retries=5, base_delay=1): -# def decorator(func): -# @wraps(func) -# def wrapper(*args, **kwargs): -# for attempt in range(max_retries): -# try: -# return func(*args, **kwargs) -# except Exception as e: -# if attempt == max_retries - 1: -# raise -# delay = base_delay * (2 ** attempt) -# logging.warning(f"Attempt {attempt + 1} failed. Retrying in {delay} seconds. 
Error: {str(e)}") -# time.sleep(delay) -# return wrapper -# return decorator -# -# -# # FIXME - refactor/setup to use config file & perform chunking -# @exponential_backoff() -# @RateLimiter(max_calls=50, period=60) -# def create_embeddings_batch(texts: List[str], provider: str, model: str, api_url: str, timeout_seconds: int = 300) -> List[List[float]]: -# global embedding_models -# -# try: -# if provider.lower() == 'huggingface': -# if model not in embedding_models: -# if model == "dunzhang/stella_en_400M_v5": -# embedding_models[model] = ONNXEmbedder(model, model_dir, timeout_seconds) -# else: -# embedding_models[model] = HuggingFaceEmbedder(model, timeout_seconds) -# embedder = embedding_models[model] -# return embedder.create_embeddings(texts) -# -# elif provider.lower() == 'openai': -# logging.debug(f"Creating embeddings for {len(texts)} texts using OpenAI API") -# return [create_openai_embedding(text, model) for text in texts] -# -# elif provider.lower() == 'local': -# response = requests.post( -# api_url, -# json={"texts": texts, "model": model}, -# headers={"Authorization": f"Bearer {embedding_api_key}"} -# ) -# if response.status_code == 200: -# return response.json()['embeddings'] -# else: -# raise Exception(f"Error from local API: {response.text}") -# else: -# raise ValueError(f"Unsupported embedding provider: {provider}") -# except Exception as e: -# logging.error(f"Error in create_embeddings_batch: {str(e)}") -# raise -# -# def create_embedding(text: str, provider: str, model: str, api_url: str) -> List[float]: -# return create_embeddings_batch([text], provider, model, api_url)[0] -# -# -# def create_openai_embedding(text: str, model: str) -> List[float]: -# embedding = get_openai_embeddings(text, model) -# return embedding -# -# -# # FIXME - refactor to use onnx embeddings callout -# def create_stella_embeddings(text: str) -> List[float]: -# if embedding_provider == 'local': -# # Load the model and tokenizer -# tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_400M_v5") -# model = AutoModel.from_pretrained("dunzhang/stella_en_400M_v5") -# -# # Tokenize and encode the text -# inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) -# -# # Generate embeddings -# with torch.no_grad(): -# outputs = model(**inputs) -# -# # Use the mean of the last hidden state as the sentence embedding -# embeddings = outputs.last_hidden_state.mean(dim=1) -# -# return embeddings[0].tolist() # Convert to list for consistency -# elif embedding_provider == 'openai': -# return get_openai_embeddings(text, embedding_model) -# else: -# raise ValueError(f"Unsupported embedding provider: {embedding_provider}") -# # -# # End of F -# ############################################################## -# # # ############################################################## # # diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 10c8c5dfc..ad80eef1e 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -9,14 +9,16 @@ from typing import Dict, Any, List, Optional from App_Function_Libraries.DB.Character_Chat_DB import get_character_chats, perform_full_text_search_chat, \ - fetch_keywords_for_chats + fetch_keywords_for_chats, search_character_chat, search_character_cards, fetch_character_ids_by_keywords +from App_Function_Libraries.DB.RAG_QA_Chat_DB import search_rag_chat, search_rag_notes # # Local Imports from App_Function_Libraries.RAG.ChromaDB_Library import 
process_and_store_content, vector_search, chroma_client from App_Function_Libraries.RAG.RAG_Persona_Chat import perform_vector_search_chat from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_custom_openai from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_article -from App_Function_Libraries.DB.DB_Manager import search_db, fetch_keywords_for_media +from App_Function_Libraries.DB.DB_Manager import fetch_keywords_for_media, search_media_db, get_notes_by_keywords, \ + search_conversations_by_keywords from App_Function_Libraries.Utils.Utils import load_comprehensive_config from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # @@ -40,6 +42,15 @@ # Read the configuration file config.read('config.txt') + +search_functions = { + "Media DB": search_media_db, + "RAG Chat": search_rag_chat, + "RAG Notes": search_rag_notes, + "Character Chat": search_character_chat, + "Character Cards": search_character_cards +} + # RAG pipeline function for web scraping # def rag_web_scraping_pipeline(url: str, query: str, api_choice=None) -> Dict[str, Any]: # try: @@ -117,7 +128,20 @@ # RAG Search with keyword filtering # FIXME - Update each called function to support modifiable top-k results -def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top_k=10, apply_re_ranking=True) -> Dict[str, Any]: +def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts_top_k=10, apply_re_ranking=True, database_types: List[str] = "Media DB") -> Dict[str, Any]: + """ + Perform full text search across specified database type. + + Args: + query: Search query string + api_choice: API to use for generating the response + fts_top_k: Maximum number of results to return + keywords: Optional list of media IDs to filter results + database_types: Type of database to search ("Media DB", "RAG Chat", or "Character Chat") + + Returns: + Dictionary containing search results with content + """ log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice}) start_time = time.time() try: @@ -131,16 +155,97 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else [] logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}") - # Fetch relevant media IDs based on keywords if keywords are provided - relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None - logging.debug(f"\n\nenhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}") + relevant_ids = {} + + # Fetch relevant IDs based on keywords if keywords are provided + if keyword_list: + try: + for db_type in database_types: + if db_type == "Media DB": + relevant_media_ids = fetch_relevant_media_ids(keyword_list) + relevant_ids[db_type] = relevant_media_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant media IDs: {relevant_media_ids}") + + elif db_type == "RAG Chat": + conversations, total_pages, total_count = search_conversations_by_keywords( + keywords=keyword_list) + relevant_conversation_ids = [conv['conversation_id'] for conv in conversations] + relevant_ids[db_type] = relevant_conversation_ids + logging.debug( + f"enhanced_rag_pipeline - {db_type} relevant conversation IDs: {relevant_conversation_ids}") + + elif db_type == "RAG Notes": + notes, total_pages, total_count = get_notes_by_keywords(keyword_list) + relevant_note_ids = [note_id for note_id, _, _, _ in notes] # 
Unpack note_id from the tuple + relevant_ids[db_type] = relevant_note_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant note IDs: {relevant_note_ids}") + + elif db_type == "Character Chat": + relevant_chat_ids = fetch_keywords_for_chats(keyword_list) + relevant_ids[db_type] = relevant_chat_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant chat IDs: {relevant_chat_ids}") + + elif db_type == "Character Cards": + # Assuming we have a function to fetch character IDs by keywords + relevant_character_ids = fetch_character_ids_by_keywords(keyword_list) + relevant_ids[db_type] = relevant_character_ids + logging.debug( + f"enhanced_rag_pipeline - {db_type} relevant character IDs: {relevant_character_ids}") + + else: + logging.error(f"Unsupported database type: {db_type}") + + except Exception as e: + logging.error(f"Error fetching relevant IDs: {str(e)}") + else: + relevant_ids = None + + # Extract relevant media IDs for each selected DB + # Prepare a dict to hold relevant_media_ids per DB + relevant_media_ids_dict = {} + if relevant_ids: + for db_type in database_types: + relevant_media_ids = relevant_ids.get(db_type, None) + if relevant_media_ids: + # Convert to List[str] if not None + relevant_media_ids_dict[db_type] = [str(media_id) for media_id in relevant_media_ids] + else: + relevant_media_ids_dict[db_type] = None + else: + relevant_media_ids_dict = {db_type: None for db_type in database_types} + + # Perform vector search for all selected databases + vector_results = [] + for db_type in database_types: + try: + db_relevant_ids = relevant_media_ids_dict.get(db_type) + results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k) + vector_results.extend(results) + logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}") + except Exception as e: + logging.error(f"Error performing vector search on {db_type}: {str(e)}") # Perform vector search + # FIXME vector_results = perform_vector_search(query, relevant_media_ids) logging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}") # Perform full-text search - fts_results = perform_full_text_search(query, relevant_media_ids) + #v1 + #fts_results = perform_full_text_search(query, database_type, relevant_media_ids, fts_top_k) + + # v2 + # Perform full-text search across specified databases + fts_results = [] + for db_type in database_types: + try: + db_relevant_ids = relevant_ids.get(db_type) if relevant_ids else None + db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k) + fts_results.extend(db_results) + logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}") + except Exception as e: + logging.error(f"Error performing full-text search on {db_type}: {str(e)}") + logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:") logging.debug( "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join( @@ -175,8 +280,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top # Update all_results based on reranking all_results = [all_results[result['id']] for result in reranked_results] - # Extract content from results (top 10 by default) - context = "\n".join([result['content'] for result in all_results[:top_k]]) + # Extract content from results (top fts_top_k by default) + context = "\n".join([result['content'] for result in all_results[:fts_top_k]]) logging.debug(f"Context length: {len(context)}") logging.debug(f"Context: {context[:200]}") @@ -208,6 +313,8 
@@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top "context": "" } + + # Need to write a test for this function FIXME def generate_answer(api_choice: str, context: str, query: str) -> str: # Metrics @@ -336,6 +443,7 @@ def generate_answer(api_choice: str, context: str, query: str) -> str: logging.error(f"Error in generate_answer: {str(e)}") return "An error occurred while generating the answer." + def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]: log_counter("perform_vector_search_attempt") start_time = time.time() @@ -344,6 +452,8 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_ try: for collection in all_collections: collection_results = vector_search(collection.name, query, k=top_k) + if not collection_results: + continue # Skip empty results filtered_results = [ result for result in collection_results if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids @@ -358,29 +468,75 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_ logging.error(f"Error in perform_vector_search: {str(e)}") raise -def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: - log_counter("perform_full_text_search_attempt") + +# V2 +def perform_full_text_search(query: str, database_type: str, relevant_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: + """ + Perform full-text search on a specified database type. + + Args: + query: Search query string + database_type: Type of database to search ("Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards") + relevant_ids: Optional list of media IDs to filter results + fts_top_k: Maximum number of results to return + + Returns: + List of search results with content and metadata + """ + log_counter("perform_full_text_search_attempt", labels={"database_type": database_type}) start_time = time.time() + try: - fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) - filtered_fts_results = [ - { - "content": result['content'], - "metadata": {"media_id": result['id']} - } - for result in fts_results - if relevant_media_ids is None or result['id'] in relevant_media_ids - ] + # Set default for fts_top_k + if fts_top_k is None: + fts_top_k = 10 + + # Call appropriate search function based on database type + if database_type not in search_functions: + raise ValueError(f"Unsupported database type: {database_type}") + + # Call the appropriate search function + results = search_functions[database_type](query, fts_top_k, relevant_ids) + search_duration = time.time() - start_time - log_histogram("perform_full_text_search_duration", search_duration) - log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)}) - return filtered_fts_results + log_histogram("perform_full_text_search_duration", search_duration, + labels={"database_type": database_type}) + log_counter("perform_full_text_search_success", + labels={"database_type": database_type, "result_count": len(results)}) + + return results + except Exception as e: - log_counter("perform_full_text_search_error", labels={"error": str(e)}) - logging.error(f"Error in perform_full_text_search: {str(e)}") + log_counter("perform_full_text_search_error", + labels={"database_type": database_type, "error": str(e)}) + logging.error(f"Error in perform_full_text_search ({database_type}): 
{str(e)}") raise +# v1 +# def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: +# log_counter("perform_full_text_search_attempt") +# start_time = time.time() +# try: +# fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) +# filtered_fts_results = [ +# { +# "content": result['content'], +# "metadata": {"media_id": result['id']} +# } +# for result in fts_results +# if relevant_media_ids is None or result['id'] in relevant_media_ids +# ] +# search_duration = time.time() - start_time +# log_histogram("perform_full_text_search_duration", search_duration) +# log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)}) +# return filtered_fts_results +# except Exception as e: +# log_counter("perform_full_text_search_error", labels={"error": str(e)}) +# logging.error(f"Error in perform_full_text_search: {str(e)}") +# raise + + def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]: log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)}) start_time = time.time() @@ -502,6 +658,7 @@ def enhanced_rag_pipeline_chat(query: str, api_choice: str, character_id: int, k logging.debug(f"enhanced_rag_pipeline_chat - Vector search results: {vector_results}") # Perform full-text search within the relevant chats + # FIXME - Update for DB Selection fts_results = perform_full_text_search_chat(query, relevant_chat_ids) logging.debug("enhanced_rag_pipeline_chat - Full-text search results:") logging.debug("\n".join([str(item) for item in fts_results])) diff --git a/App_Function_Libraries/RAG/RAG_QA_Chat.py b/App_Function_Libraries/RAG/RAG_QA_Chat.py index 5ec3af076..1c6580c06 100644 --- a/App_Function_Libraries/RAG/RAG_QA_Chat.py +++ b/App_Function_Libraries/RAG/RAG_QA_Chat.py @@ -12,7 +12,7 @@ from typing import List, Tuple, IO, Union # # Local Imports -from App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content +from App_Function_Libraries.DB.DB_Manager import db, search_media_db, DatabaseError, get_media_content from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer, enhanced_rag_pipeline from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # @@ -89,7 +89,7 @@ def search_database(query: str) -> List[Tuple[int, str]]: log_counter("search_database_attempt") start_time = time.time() # Implement database search functionality - results = search_db(query, ["title", "content"], "", page=1, results_per_page=10) + results = search_media_db(query, ["title", "content"], "", page=1, results_per_page=10) search_duration = time.time() - start_time log_histogram("search_database_duration", search_duration) log_counter("search_database_success", labels={"result_count": len(results)}) diff --git a/App_Function_Libraries/Utils/Utils.py b/App_Function_Libraries/Utils/Utils.py index 294571455..5e2e4f0db 100644 --- a/App_Function_Libraries/Utils/Utils.py +++ b/App_Function_Libraries/Utils/Utils.py @@ -95,8 +95,6 @@ def cleanup_downloads(): ####################################################################################################################### # Config loading # - - def load_comprehensive_config(): # Get the directory of the current script (Utils.py) current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -126,25 +124,33 @@ def load_comprehensive_config(): def get_project_root(): - # Get the directory of the current file (Utils.py) + """Get the 
absolute path to the project root directory.""" current_dir = os.path.dirname(os.path.abspath(__file__)) - # Go up two levels to reach the project root - # Assuming the structure is: project_root/App_Function_Libraries/Utils/Utils.py project_root = os.path.dirname(os.path.dirname(current_dir)) + logging.debug(f"Project root: {project_root}") return project_root + def get_database_dir(): - """Get the database directory (/tldw/Databases/).""" + """Get the absolute path to the database directory.""" db_dir = os.path.join(get_project_root(), 'Databases') + os.makedirs(db_dir, exist_ok=True) logging.debug(f"Database directory: {db_dir}") return db_dir -def get_database_path(db_name: Union[str, os.PathLike[AnyStr]]) -> str: - """Get the full path for a database file.""" - path = os.path.join(get_database_dir(), str(db_name)) - logging.debug(f"Database path for {db_name}: {path}") + +def get_database_path(db_name: str) -> str: + """ + Get the full absolute path for a database file. + Ensures the path is always within the Databases directory. + """ + # Remove any directory traversal attempts + safe_db_name = os.path.basename(db_name) + path = os.path.join(get_database_dir(), safe_db_name) + logging.debug(f"Database path for {safe_db_name}: {path}") return path + def get_project_relative_path(relative_path: Union[str, os.PathLike[AnyStr]]) -> str: """Convert a relative path to a path relative to the project root.""" path = os.path.join(get_project_root(), str(relative_path)) @@ -280,6 +286,10 @@ def load_and_log_configs(): # Prompts - FIXME prompt_path = config.get('Prompts', 'prompt_path', fallback='Databases/prompts.db') + # Auto-Save Values + save_character_chats = config.get('Auto-Save', 'save_character_chats', fallback='False') + save_rag_chats = config.get('Auto-Save', 'save_rag_chats', fallback='False') + return { 'api_keys': { 'anthropic': anthropic_api_key, @@ -343,6 +353,10 @@ def load_and_log_configs(): 'chunk_size': chunk_size, 'overlap': overlap }, + 'auto-save': { + 'save_character_chats': save_character_chats, + 'save_rag_chats': save_rag_chats, + }, 'default_api': default_api } diff --git a/Config_Files/Backup_Config.txt b/Config_Files/Backup_Config.txt index 46d24c7cf..3d497b8b8 100644 --- a/Config_Files/Backup_Config.txt +++ b/Config_Files/Backup_Config.txt @@ -17,23 +17,24 @@ mistral_model = mistral-large-latest mistral_api_key =
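Stepping back from the individual hunks: the RAG changes in this diff route full-text search through a dispatch table (search_functions) keyed by database type, and enhanced_rag_pipeline then aggregates results across every selected database. Here is a minimal sketch of that dispatch pattern, with stub backends standing in for search_media_db, search_rag_chat, and the other search functions; the stubs, their return shape, and the example query are illustrative assumptions rather than the project's real implementations.

from typing import Any, Dict, List, Optional

def make_stub_backend(name: str):
    # Stand-in for search_media_db / search_rag_chat / search_rag_notes;
    # the real functions are invoked as fn(query, fts_top_k, relevant_ids).
    def search(query: str, top_k: int,
               relevant_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        return [{"content": f"{name} hit for {query!r}",
                 "metadata": {"source": name}}][:top_k]
    return search

search_functions = {
    "Media DB": make_stub_backend("Media DB"),
    "RAG Chat": make_stub_backend("RAG Chat"),
    "RAG Notes": make_stub_backend("RAG Notes"),
}

def perform_full_text_search(query: str, database_type: str,
                             relevant_ids: Optional[List[str]] = None,
                             fts_top_k: int = 10) -> List[Dict[str, Any]]:
    # Unknown database types fail loudly, as in the patched function.
    if database_type not in search_functions:
        raise ValueError(f"Unsupported database type: {database_type}")
    return search_functions[database_type](query, fts_top_k, relevant_ids)

# enhanced_rag_pipeline-style aggregation across the selected databases:
results: List[Dict[str, Any]] = []
for db_type in ("Media DB", "RAG Notes"):
    results.extend(perform_full_text_search("transcription pipeline", db_type))
print(results)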