From 06f9905c89431e667be26942687973a96566e8f6 Mon Sep 17 00:00:00 2001 From: Jay Rodge <jrodge@nvidia.com> Date: Thu, 29 Aug 2024 14:20:35 -0800 Subject: [PATCH] Add Multimodal RAG (llamaindex+NIMs) example to community projects (#178) * Add Multimodal RAG example to community projects * added instruction to setup Milvus GPU docker * added copyright headers --- .../llamaindex/llamaindex_basic_RAG.ipynb | 10 +- community/README.md | 2 +- community/multimodal-rag/README.md | 109 +++++++ community/multimodal-rag/app.py | 118 ++++++++ .../multimodal-rag/document_processors.py | 284 ++++++++++++++++++ community/multimodal-rag/requirements.txt | 13 + community/multimodal-rag/utils.py | 167 ++++++++++ community/multimodal-rag/vectorstore/.gitkeep | 0 .../vectorstore/image_references/.gitkeep | 0 .../vectorstore/ppt_references/.gitkeep | 0 .../vectorstore/table_references/.gitkeep | 0 11 files changed, 697 insertions(+), 6 deletions(-) create mode 100644 community/multimodal-rag/README.md create mode 100644 community/multimodal-rag/app.py create mode 100644 community/multimodal-rag/document_processors.py create mode 100644 community/multimodal-rag/requirements.txt create mode 100644 community/multimodal-rag/utils.py create mode 100644 community/multimodal-rag/vectorstore/.gitkeep create mode 100644 community/multimodal-rag/vectorstore/image_references/.gitkeep create mode 100644 community/multimodal-rag/vectorstore/ppt_references/.gitkeep create mode 100644 community/multimodal-rag/vectorstore/table_references/.gitkeep diff --git a/RAG/notebooks/llamaindex/llamaindex_basic_RAG.ipynb b/RAG/notebooks/llamaindex/llamaindex_basic_RAG.ipynb index 59ecd1f3..30e709bb 100644 --- a/RAG/notebooks/llamaindex/llamaindex_basic_RAG.ipynb +++ b/RAG/notebooks/llamaindex/llamaindex_basic_RAG.ipynb @@ -12,11 +12,11 @@ "cell_type": "markdown", "id": "2969cdab-82fc-4ce5-bde1-b4f629691f27", "metadata": {}, - "source": [ - "This notebook introduces how to use LlamaIndex to interact with NVIDIA hosted NIM microservices like chat, embedding, and reranking models to build a simple retrieval-augmented generation (RAG) application.\n", - "\n", - "Alternatively, for a more interactive experience with a graphical user interface, you can refer to our Gradio-based RAG Q&A reference application that also uses NVIDIA hosted NIM microservices [here](https://github.com/jayrodge/llm-assistant-cloud-app/)." - ] + "source": [ + "This notebook introduces how to use LlamaIndex to interact with NVIDIA hosted NIM microservices like chat, embedding, and reranking models to build a simple retrieval-augmented generation (RAG) application.\n", + "\n", + "Alternatively, for a more interactive experience with a graphical user interface, you can refer to our [code](https://github.com/jayrodge/llm-assistant-cloud-app/) and [YouTube video](https://www.youtube.com/watch?v=09uDCmLzYHA) for Gradio-based RAG Q&A reference application that also uses NVIDIA hosted NIM microservices." + ] }, { "cell_type": "markdown", diff --git a/community/README.md b/community/README.md index 6a416abf..5d7c894e 100644 --- a/community/README.md +++ b/community/README.md @@ -45,7 +45,7 @@ Community examples are sample code and deployments for RAG pipelines that are no * [NVIDIA Multimodal RAG Assistant](./multimodal_assistant) - This example is able to ingest PDFs, PowerPoint slides, Word and other documents with complex data formats including text, images, slides and tables. It allows users to ask questions through a text interface and optionally with an image query, and it can respond with text and reference images, slides and tables in its response, along with source links and downloads. + This example is able to ingest PDFs, PowerPoint slides, Word and other documents with complex data formats including text, images, slides and tables, orchestrated with Langchain. It allows users to ask questions through a text interface and optionally with an image query, and it can respond with text and reference images, slides and tables in its response, along with source links and downloads. Refer to this [example](./multimodal-rag) for the LlamaIndex version that uses [integration](https://docs.llamaindex.ai/en/stable/examples/llm/nvidia_nim/) with NVIDIA Inference Microservices (NIMs) of the Multimodal RAG Assistant. * [NVIDIA Developer RAG Chatbot](./rag-developer-chatbot) diff --git a/community/multimodal-rag/README.md b/community/multimodal-rag/README.md new file mode 100644 index 00000000..47e583cc --- /dev/null +++ b/community/multimodal-rag/README.md @@ -0,0 +1,109 @@ +# Creating Multimodal AI Agent for Enhanced Content Understanding + +## Overview + +This Streamlit application implements a Multimodal Retrieval-Augmented Generation (RAG) system. It processes various types of documents including text files, PDFs, PowerPoint presentations, and images. The app leverages Large Language Models and Vision Language Models to extract and index information from these documents, allowing users to query the processed data through an interactive chat interface. + +The system utilizes LlamaIndex for efficient indexing and retrieval of information, NVIDIA Inference Microservices (NIMs) for high-performance inference capabilities, and Milvus as a vector database for efficient storage and retrieval of embedding vectors. This combination of technologies enables the application to handle complex multimodal data, perform advanced queries, and deliver rapid, context-aware responses to user inquiries. + +## Features + +- **Multi-format Document Processing**: Handles text files, PDFs, PowerPoint presentations, and images. +- **Advanced Text Extraction**: Extracts text from PDFs and PowerPoint slides, including tables and embedded images. +- **Image Analysis**: Uses a VLM (NeVA) to describe images and Google's DePlot for processing graphs/charts on NVIDIA Inference Microservices (NIMs). +- **Vector Store Indexing**: Creates a searchable index of processed documents using Milvus vector store. +- **Interactive Chat Interface**: Allows users to query the processed information through a chat-like interface. + +## Setup + +1. Clone the repository: +``` +git clone https://github.com/NVIDIA/GenerativeAIExamples.git +cd GenerativeAIExamples/community/multimodal_rag +``` + +2. (Optional) Create a conda environment or a virtual environment: + + - Using conda: + ``` + conda create --name multimodal-rag python=3.10 + conda activate multimodal-rag + ``` + + - Using venv: + ``` + python -m venv venv + source venv/bin/activate + +3. Install the required packages: +``` +pip install -r requirements.txt +``` + +4. Set up your NVIDIA API key as an environment variable: +``` +export NVIDIA_API_KEY="your-api-key-here" +``` + +5. Refer this [tutorial](https://milvus.io/docs/install_standalone-docker-compose-gpu.md) to install and start the GPU-accelerated Milvus container: + +``` +sudo docker compose up -d +``` + + +## Usage + +1. Ensure the Milvus container is running: + +```bash +docker ps +``` + +2. Run the Streamlit app: +``` +streamlit run app.py +``` + +3. Open the provided URL in your web browser. + +4. Choose between uploading files or specifying a directory path containing your documents. + +5. Process the files by clicking the "Process Files" or "Process Directory" button. + +6. Once processing is complete, use the chat interface to query your documents. + +## File Structure + +- `app.py`: Main Streamlit application +- `utils.py`: Utility functions for image processing and API interactions +- `document_processors.py`: Functions for processing various document types +- `requirements.txt`: List of Python dependencies +- `vectorstore/` : Repository to store information from pdfs and ppt + + +## GPU Acceleration for Vector Search +To utilize GPU acceleration in the vector database, ensure that: +1. Your system has a compatible NVIDIA GPU. +2. You're using the GPU-enabled version of Milvus (as shown in the setup instructions). +3. There are enough concurrent requests to justify GPU usage. GPU acceleration typically shows significant benefits under high load conditions. + +It's important to note that GPU acceleration will only be used when the incoming requests are extremely high. For more detailed information on GPU indexing and search in Milvus, refer to the [official Milvus GPU Index documentation](https://milvus.io/docs/gpu_index.md). + +To connect the GPU-accelerated Milvus with LlamaIndex, update the MilvusVectorStore configuration in app.py: +``` +vector_store = MilvusVectorStore( + host="127.0.0.1", + port=19530, + dim=1024, + collection_name="your_collection_name", + gpu_id=0 # Specify the GPU ID to use +) +``` + +## Contributing +Contributions to this project are welcome! Please follow these steps: +1. Fork the NVIDIA/GenerativeAIExamples repository. +2. Create a new branch for your feature or bug fix. +3. Make your changes in the community/multimodal_rag/ directory. +4. Submit a pull request to the main repository. \ No newline at end of file diff --git a/community/multimodal-rag/app.py b/community/multimodal-rag/app.py new file mode 100644 index 00000000..e99c5da9 --- /dev/null +++ b/community/multimodal-rag/app.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import streamlit as st +from llama_index.core import Settings +from llama_index.core import VectorStoreIndex, StorageContext +from llama_index.core.node_parser import SentenceSplitter +from llama_index.vector_stores.milvus import MilvusVectorStore +from llama_index.embeddings.nvidia import NVIDIAEmbedding +from llama_index.llms.nvidia import NVIDIA + +from document_processors import load_multimodal_data, load_data_from_directory +from utils import set_environment_variables + +# Set up the page configuration +st.set_page_config(layout="wide") + +# Initialize settings +def initialize_settings(): + Settings.embed_model = NVIDIAEmbedding(model="NV-Embed-QA", truncate="END") + Settings.llm = NVIDIA(model="meta/llama-3.1-70b-instruct") + Settings.text_splitter = SentenceSplitter(chunk_size=600) + +# Create index from documents +def create_index(documents): + vector_store = MilvusVectorStore( + host = "127.0.0.1", + port = 19530, + dim = 1024 + ) + # vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True) #For CPU only vector store + storage_context = StorageContext.from_defaults(vector_store=vector_store) + return VectorStoreIndex.from_documents(documents, storage_context=storage_context) + +# Main function to run the Streamlit app +def main(): + set_environment_variables() + initialize_settings() + + col1, col2 = st.columns([1, 2]) + + with col1: + st.title("Multimodal RAG") + + input_method = st.radio("Choose input method:", ("Upload Files", "Enter Directory Path")) + + if input_method == "Upload Files": + uploaded_files = st.file_uploader("Drag and drop files here", accept_multiple_files=True) + if uploaded_files and st.button("Process Files"): + with st.spinner("Processing files..."): + documents = load_multimodal_data(uploaded_files) + st.session_state['index'] = create_index(documents) + st.session_state['history'] = [] + st.success("Files processed and index created!") + else: + directory_path = st.text_input("Enter directory path:") + if directory_path and st.button("Process Directory"): + if os.path.isdir(directory_path): + with st.spinner("Processing directory..."): + documents = load_data_from_directory(directory_path) + st.session_state['index'] = create_index(documents) + st.session_state['history'] = [] + st.success("Directory processed and index created!") + else: + st.error("Invalid directory path. Please enter a valid path.") + + with col2: + if 'index' in st.session_state: + st.title("Chat") + if 'history' not in st.session_state: + st.session_state['history'] = [] + + query_engine = st.session_state['index'].as_query_engine(similarity_top_k=20, streaming=True) + + user_input = st.chat_input("Enter your query:") + + # Display chat messages + chat_container = st.container() + with chat_container: + for message in st.session_state['history']: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + if user_input: + with st.chat_message("user"): + st.markdown(user_input) + st.session_state['history'].append({"role": "user", "content": user_input}) + + with st.chat_message("assistant"): + message_placeholder = st.empty() + full_response = "" + response = query_engine.query(user_input) + for token in response.response_gen: + full_response += token + message_placeholder.markdown(full_response + "▌") + message_placeholder.markdown(full_response) + st.session_state['history'].append({"role": "assistant", "content": full_response}) + + # Add a clear button + if st.button("Clear Chat"): + st.session_state['history'] = [] + st.rerun() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/community/multimodal-rag/document_processors.py b/community/multimodal-rag/document_processors.py new file mode 100644 index 00000000..660de3cf --- /dev/null +++ b/community/multimodal-rag/document_processors.py @@ -0,0 +1,284 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import fitz +from pptx import Presentation +import subprocess +from llama_index.core import Document +from utils import ( + describe_image, is_graph, process_graph, extract_text_around_item, + process_text_blocks, save_uploaded_file +) + +def get_pdf_documents(pdf_file): + """Process a PDF file and extract text, tables, and images.""" + all_pdf_documents = [] + ongoing_tables = {} + + try: + f = fitz.open(stream=pdf_file.read(), filetype="pdf") + except Exception as e: + print(f"Error opening or processing the PDF file: {e}") + return [] + + for i in range(len(f)): + page = f[i] + text_blocks = [block for block in page.get_text("blocks", sort=True) + if block[-1] == 0 and not (block[1] < page.rect.height * 0.1 or block[3] > page.rect.height * 0.9)] + grouped_text_blocks = process_text_blocks(text_blocks) + + table_docs, table_bboxes, ongoing_tables = parse_all_tables(pdf_file.name, page, i, text_blocks, ongoing_tables) + all_pdf_documents.extend(table_docs) + + image_docs = parse_all_images(pdf_file.name, page, i, text_blocks) + all_pdf_documents.extend(image_docs) + + for text_block_ctr, (heading_block, content) in enumerate(grouped_text_blocks, 1): + heading_bbox = fitz.Rect(heading_block[:4]) + if not any(heading_bbox.intersects(table_bbox) for table_bbox in table_bboxes): + bbox = {"x1": heading_block[0], "y1": heading_block[1], "x2": heading_block[2], "x3": heading_block[3]} + text_doc = Document( + text=f"{heading_block[4]}\n{content}", + metadata={ + **bbox, + "type": "text", + "page_num": i, + "source": f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}" + }, + id_=f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}" + ) + all_pdf_documents.append(text_doc) + + f.close() + return all_pdf_documents + +def parse_all_tables(filename, page, pagenum, text_blocks, ongoing_tables): + """Extract tables from a PDF page.""" + table_docs = [] + table_bboxes = [] + try: + tables = page.find_tables(horizontal_strategy="lines_strict", vertical_strategy="lines_strict") + for tab in tables: + if not tab.header.external: + pandas_df = tab.to_pandas() + tablerefdir = os.path.join(os.getcwd(), "vectorstore/table_references") + os.makedirs(tablerefdir, exist_ok=True) + df_xlsx_path = os.path.join(tablerefdir, f"table{len(table_docs)+1}-page{pagenum}.xlsx") + pandas_df.to_excel(df_xlsx_path) + bbox = fitz.Rect(tab.bbox) + table_bboxes.append(bbox) + + before_text, after_text = extract_text_around_item(text_blocks, bbox, page.rect.height) + + table_img = page.get_pixmap(clip=bbox) + table_img_path = os.path.join(tablerefdir, f"table{len(table_docs)+1}-page{pagenum}.jpg") + table_img.save(table_img_path) + description = process_graph(table_img.tobytes()) + + caption = before_text.replace("\n", " ") + description + after_text.replace("\n", " ") + if before_text == "" and after_text == "": + caption = " ".join(tab.header.names) + table_metadata = { + "source": f"{filename[:-4]}-page{pagenum}-table{len(table_docs)+1}", + "dataframe": df_xlsx_path, + "image": table_img_path, + "caption": caption, + "type": "table", + "page_num": pagenum + } + all_cols = ", ".join(list(pandas_df.columns.values)) + doc = Document(text=f"This is a table with the caption: {caption}\nThe columns are {all_cols}", metadata=table_metadata) + table_docs.append(doc) + except Exception as e: + print(f"Error during table extraction: {e}") + return table_docs, table_bboxes, ongoing_tables + +def parse_all_images(filename, page, pagenum, text_blocks): + """Extract images from a PDF page.""" + image_docs = [] + image_info_list = page.get_image_info(xrefs=True) + page_rect = page.rect + + for image_info in image_info_list: + xref = image_info['xref'] + if xref == 0: + continue + + img_bbox = fitz.Rect(image_info['bbox']) + if img_bbox.width < page_rect.width / 20 or img_bbox.height < page_rect.height / 20: + continue + + extracted_image = page.parent.extract_image(xref) + image_data = extracted_image["image"] + imgrefpath = os.path.join(os.getcwd(), "vectorstore/image_references") + os.makedirs(imgrefpath, exist_ok=True) + image_path = os.path.join(imgrefpath, f"image{xref}-page{pagenum}.png") + with open(image_path, "wb") as img_file: + img_file.write(image_data) + + before_text, after_text = extract_text_around_item(text_blocks, img_bbox, page.rect.height) + if before_text == "" and after_text == "": + continue + + image_description = " " + if is_graph(image_data): + image_description = process_graph(image_data) + + caption = before_text.replace("\n", " ") + image_description + after_text.replace("\n", " ") + + image_metadata = { + "source": f"{filename[:-4]}-page{pagenum}-image{xref}", + "image": image_path, + "caption": caption, + "type": "image", + "page_num": pagenum + } + image_docs.append(Document(text="This is an image with the caption: " + caption, metadata=image_metadata)) + return image_docs + +def process_ppt_file(ppt_path): + """Process a PowerPoint file.""" + pdf_path = convert_ppt_to_pdf(ppt_path) + images_data = convert_pdf_to_images(pdf_path) + slide_texts = extract_text_and_notes_from_ppt(ppt_path) + processed_data = [] + + for (image_path, page_num), (slide_text, notes) in zip(images_data, slide_texts): + if notes: + notes = "\n\nThe speaker notes for this slide are: " + notes + + with open(image_path, 'rb') as image_file: + image_content = image_file.read() + + image_description = " " + if is_graph(image_content): + image_description = process_graph(image_content) + + image_metadata = { + "source": f"{os.path.basename(ppt_path)}", + "image": image_path, + "caption": slide_text + image_description + notes, + "type": "image", + "page_num": page_num + } + processed_data.append(Document(text="This is a slide with the text: " + slide_text + image_description, metadata=image_metadata)) + + return processed_data + +def convert_ppt_to_pdf(ppt_path): + """Convert a PowerPoint file to PDF using LibreOffice.""" + base_name = os.path.basename(ppt_path) + ppt_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_') + new_dir_path = os.path.abspath("vectorstore/ppt_references") + os.makedirs(new_dir_path, exist_ok=True) + pdf_path = os.path.join(new_dir_path, f"{ppt_name_without_ext}.pdf") + command = ['libreoffice', '--headless', '--convert-to', 'pdf', '--outdir', new_dir_path, ppt_path] + subprocess.run(command, check=True) + return pdf_path + +def convert_pdf_to_images(pdf_path): + """Convert a PDF file to a series of images using PyMuPDF.""" + doc = fitz.open(pdf_path) + base_name = os.path.basename(pdf_path) + pdf_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_') + new_dir_path = os.path.join(os.getcwd(), "vectorstore/ppt_references") + os.makedirs(new_dir_path, exist_ok=True) + image_paths = [] + + for page_num in range(len(doc)): + page = doc.load_page(page_num) + pix = page.get_pixmap() + output_image_path = os.path.join(new_dir_path, f"{pdf_name_without_ext}_{page_num:04d}.png") + pix.save(output_image_path) + image_paths.append((output_image_path, page_num)) + doc.close() + return image_paths + +def extract_text_and_notes_from_ppt(ppt_path): + """Extract text and notes from a PowerPoint file.""" + prs = Presentation(ppt_path) + text_and_notes = [] + for slide in prs.slides: + slide_text = ' '.join([shape.text for shape in slide.shapes if hasattr(shape, "text")]) + try: + notes = slide.notes_slide.notes_text_frame.text if slide.notes_slide else '' + except: + notes = '' + text_and_notes.append((slide_text, notes)) + return text_and_notes + +def load_multimodal_data(files): + """Load and process multiple file types.""" + documents = [] + for file in files: + file_extension = os.path.splitext(file.name.lower())[1] + if file_extension in ('.png', '.jpg', '.jpeg'): + image_content = file.read() + image_text = describe_image(image_content) + doc = Document(text=image_text, metadata={"source": file.name, "type": "image"}) + documents.append(doc) + elif file_extension == '.pdf': + try: + pdf_documents = get_pdf_documents(file) + documents.extend(pdf_documents) + except Exception as e: + print(f"Error processing PDF {file.name}: {e}") + elif file_extension in ('.ppt', '.pptx'): + try: + ppt_documents = process_ppt_file(save_uploaded_file(file)) + documents.extend(ppt_documents) + except Exception as e: + print(f"Error processing PPT {file.name}: {e}") + else: + text = file.read().decode("utf-8") + doc = Document(text=text, metadata={"source": file.name, "type": "text"}) + documents.append(doc) + return documents + +def load_data_from_directory(directory): + """Load and process multiple file types from a directory.""" + documents = [] + for filename in os.listdir(directory): + filepath = os.path.join(directory, filename) + file_extension = os.path.splitext(filename.lower())[1] + print(filename) + if file_extension in ('.png', '.jpg', '.jpeg'): + with open(filepath, "rb") as image_file: + image_content = image_file.read() + image_text = describe_image(image_content) + doc = Document(text=image_text, metadata={"source": filename, "type": "image"}) + print(doc) + documents.append(doc) + elif file_extension == '.pdf': + with open(filepath, "rb") as pdf_file: + try: + pdf_documents = get_pdf_documents(pdf_file) + documents.extend(pdf_documents) + except Exception as e: + print(f"Error processing PDF {filename}: {e}") + elif file_extension in ('.ppt', '.pptx'): + try: + ppt_documents = process_ppt_file(filepath) + documents.extend(ppt_documents) + print(ppt_documents) + except Exception as e: + print(f"Error processing PPT {filename}: {e}") + else: + with open(filepath, "r", encoding="utf-8") as text_file: + text = text_file.read() + doc = Document(text=text, metadata={"source": filename, "type": "text"}) + documents.append(doc) + return documents \ No newline at end of file diff --git a/community/multimodal-rag/requirements.txt b/community/multimodal-rag/requirements.txt new file mode 100644 index 00000000..942d55d2 --- /dev/null +++ b/community/multimodal-rag/requirements.txt @@ -0,0 +1,13 @@ +pymupdf==1.22.5 +streamlit==1.38.0 +fitz==0.0.1.dev2 +python-pptx==1.0.2 +Pillow==10.4.0 +requests==2.32.3 +frontend==0.0.3 +llama-index-core==0.10.58 +llama-index-readers-file==0.1.30 +llama-index-llms-nvidia==0.1.4 +llama-index-embeddings-nvidia==0.1.4 +llama-index-vector-stores-milvus==0.1.20 +pymilvus==2.4.4 diff --git a/community/multimodal-rag/utils.py b/community/multimodal-rag/utils.py new file mode 100644 index 00000000..61d3d383 --- /dev/null +++ b/community/multimodal-rag/utils.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import base64 +import fitz +from io import BytesIO +from PIL import Image +import requests +from llama_index.llms.nvidia import NVIDIA + +def set_environment_variables(): + """Set necessary environment variables.""" + os.environ["NVIDIA_API_KEY"] = "" #set API key + +def get_b64_image_from_content(image_content): + """Convert image content to base64 encoded string.""" + img = Image.open(BytesIO(image_content)) + if img.mode != 'RGB': + img = img.convert('RGB') + buffered = BytesIO() + img.save(buffered, format="JPEG") + return base64.b64encode(buffered.getvalue()).decode("utf-8") + +def is_graph(image_content): + """Determine if an image is a graph, plot, chart, or table.""" + res = describe_image(image_content) + return any(keyword in res.lower() for keyword in ["graph", "plot", "chart", "table"]) + +def process_graph(image_content): + """Process a graph image and generate a description.""" + deplot_description = process_graph_deplot(image_content) + mixtral = NVIDIA(model_name="meta/llama-3.1-70b-instruct") + response = mixtral.complete("Your responsibility is to explain charts. You are an expert in describing the responses of linearized tables into plain English text for LLMs to use. Explain the following linearized table. " + deplot_description) + return response.text + +def describe_image(image_content): + """Generate a description of an image using NVIDIA API.""" + image_b64 = get_b64_image_from_content(image_content) + invoke_url = "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b" + api_key = os.getenv("NVIDIA_API_KEY") + + if not api_key: + raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.") + + headers = { + "Authorization": f"Bearer {api_key}", + "Accept": "application/json" + } + + payload = { + "messages": [ + { + "role": "user", + "content": f'Describe what you see in this image. <img src="data:image/png;base64,{image_b64}" />' + } + ], + "max_tokens": 1024, + "temperature": 0.20, + "top_p": 0.70, + "seed": 0, + "stream": False + } + + response = requests.post(invoke_url, headers=headers, json=payload) + return response.json()["choices"][0]['message']['content'] + +def process_graph_deplot(image_content): + """Process a graph image using NVIDIA's Deplot API.""" + invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot" + image_b64 = get_b64_image_from_content(image_content) + api_key = os.getenv("NVIDIA_API_KEY") + + if not api_key: + raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.") + + headers = { + "Authorization": f"Bearer {api_key}", + "Accept": "application/json" + } + + payload = { + "messages": [ + { + "role": "user", + "content": f'Generate underlying data table of the figure below: <img src="data:image/png;base64,{image_b64}" />' + } + ], + "max_tokens": 1024, + "temperature": 0.20, + "top_p": 0.20, + "stream": False + } + + response = requests.post(invoke_url, headers=headers, json=payload) + return response.json()["choices"][0]['message']['content'] + +def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1): + """Extract text above and below a given bounding box on a page.""" + before_text, after_text = "", "" + vertical_threshold_distance = page_height * threshold_percentage + horizontal_threshold_distance = bbox.width * threshold_percentage + + for block in text_blocks: + block_bbox = fitz.Rect(block[:4]) + vertical_distance = min(abs(block_bbox.y1 - bbox.y0), abs(block_bbox.y0 - bbox.y1)) + horizontal_overlap = max(0, min(block_bbox.x1, bbox.x1) - max(block_bbox.x0, bbox.x0)) + + if vertical_distance <= vertical_threshold_distance and horizontal_overlap >= -horizontal_threshold_distance: + if block_bbox.y1 < bbox.y0 and not before_text: + before_text = block[4] + elif block_bbox.y0 > bbox.y1 and not after_text: + after_text = block[4] + break + + return before_text, after_text + +def process_text_blocks(text_blocks, char_count_threshold=500): + """Group text blocks based on a character count threshold.""" + current_group = [] + grouped_blocks = [] + current_char_count = 0 + + for block in text_blocks: + if block[-1] == 0: # Check if the block is of text type + block_text = block[4] + block_char_count = len(block_text) + + if current_char_count + block_char_count <= char_count_threshold: + current_group.append(block) + current_char_count += block_char_count + else: + if current_group: + grouped_content = "\n".join([b[4] for b in current_group]) + grouped_blocks.append((current_group[0], grouped_content)) + current_group = [block] + current_char_count = block_char_count + + # Append the last group + if current_group: + grouped_content = "\n".join([b[4] for b in current_group]) + grouped_blocks.append((current_group[0], grouped_content)) + + return grouped_blocks + +def save_uploaded_file(uploaded_file): + """Save an uploaded file to a temporary directory.""" + temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp") + os.makedirs(temp_dir, exist_ok=True) + temp_file_path = os.path.join(temp_dir, uploaded_file.name) + + with open(temp_file_path, "wb") as temp_file: + temp_file.write(uploaded_file.read()) + + return temp_file_path \ No newline at end of file diff --git a/community/multimodal-rag/vectorstore/.gitkeep b/community/multimodal-rag/vectorstore/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/community/multimodal-rag/vectorstore/image_references/.gitkeep b/community/multimodal-rag/vectorstore/image_references/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/community/multimodal-rag/vectorstore/ppt_references/.gitkeep b/community/multimodal-rag/vectorstore/ppt_references/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/community/multimodal-rag/vectorstore/table_references/.gitkeep b/community/multimodal-rag/vectorstore/table_references/.gitkeep new file mode 100644 index 00000000..e69de29b