Add support for PDF file uploads as context for LLM queries #3638

Open · wants to merge 22 commits into base: main

Changes from 3 commits (22 commits total)
118 changes: 107 additions & 11 deletions fastchat/serve/gradio_block_arena_vision_anony.py
@@ -4,12 +4,16 @@
"""

import json
import subprocess
import time

import gradio as gr
import numpy as np
from typing import Union

import os
import PyPDF2

from fastchat.constants import (
    TEXT_MODERATION_MSG,
    IMAGE_MODERATION_MSG,
@@ -242,6 +246,71 @@ def clear_history(request: gr.Request):
+ [""]
)

def extract_text_from_pdf(pdf_file_path):
    """Extract text from a PDF file."""
    try:
        with open(pdf_file_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            pdf_text = ""
            for page in reader.pages:
                # extract_text() may return None for pages without a text layer
                pdf_text += page.extract_text() or ""
        return pdf_text
    except Exception as e:
        logger.error(f"Failed to extract text from PDF: {e}")
        return None
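# A minimal usage sketch of the helper above; the path is hypothetical and
# PyPDF2 must be installed:
#
#   document_text = extract_text_from_pdf("/tmp/uploads/report.pdf")
#   if document_text is None:
#       logger.warning("PDF text extraction failed; no document context added")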

import nest_asyncio
from llama_parse import LlamaParse

nest_asyncio.apply()  # Patch the running event loop so LlamaParse's async parsing works under Gradio

def pdf_parse(pdf_path):
    # Set API key; it can also be configured in the environment
    api_key = "LLAMA API"

    # Initialize the LlamaParse object
    parser = LlamaParse(
        api_key=api_key,
        result_type="markdown",  # Output in Markdown format
        num_workers=4,  # Number of API calls for batch processing
        verbose=True,  # Print detailed logs
        language="en",  # Set language (default is English)
    )

    # Prepare the output directory and file name
    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)

    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    # Reserved for persisting the parse to disk; not written to yet
    markdown_file_path = os.path.join(output_dir, f"{pdf_name}.md")

    # Load and parse the PDF
    extra_info = {"file_name": pdf_name}

    with open(pdf_path, "rb") as pdf_file:
        # Pass the file object and extra info for parsing
        documents = parser.load_data(pdf_file, extra_info=extra_info)

    # Return the Markdown text of the first parsed document (empty if parsing failed)
    markdown_content = documents[0].text if documents else ""

    return markdown_content
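# Usage sketch for the parser above; assumes a valid LlamaParse API key has
# been configured, and the path is hypothetical:
#
#   markdown = pdf_parse("/tmp/uploads/report.pdf")
#   print(markdown[:200])  # preview the first 200 characters of the parse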

def wrap_query_context(user_query, query_context):
    # TODO: refactor to split up user query and query context, e.g.:
    #   lines = input.split("\n\n[USER QUERY]", 1)
    #   user_query = lines[1].strip()
    #   query_context = lines[0][len("[QUERY CONTEXT]\n\n"):]
    reformatted_query_context = (
        f"[QUERY CONTEXT]\n"
        f"<details>\n"
        f"<summary>Expand context details</summary>\n\n"
        f"{query_context}\n\n"
        f"</details>"
    )
    markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
    return markdown
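# For illustration, wrap_query_context("What is the main finding?", doc_md)
# returns a string of the form:
#
#   [QUERY CONTEXT]
#   <details>
#   <summary>Expand context details</summary>
#
#   <contents of doc_md>
#
#   </details>
#
#   [USER QUERY]
#
#   What is the main finding?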

def add_text(
    state0,
@@ -253,10 +322,14 @@ def add_text(
    request: gr.Request,
):
    if isinstance(chat_input, dict):
-        text, images = chat_input["text"], chat_input["files"]
+        text, files = chat_input["text"], chat_input["files"]
    else:
        text = chat_input
-        images = []
+        files = []

    images = []

    # Empty extension when nothing was uploaded (avoids an IndexError on files[0])
    file_extension = os.path.splitext(files[0])[1].lower() if files else ""
    ip = get_ip(request)
    logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -267,7 +340,7 @@
    if states[0] is None:
        assert states[1] is None

-    if len(images) > 0:
+    if len(files) > 0 and file_extension != ".pdf":
        model_left, model_right = get_battle_pair(
            context.all_vision_models,
            VISION_BATTLE_TARGETS,
@@ -350,7 +423,8 @@ def add_text(
+ [""]
)

text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
if file_extension != ".pdf":
text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
post_processed_text = _prepare_text_with_image(
states[i], text, images, csam_flag=csam_flag
@@ -363,6 +437,27 @@
    for i in range(num_sides):
        if "deluxe" in states[i].model_name:
            hint_msg = SLOW_MODEL_MSG

    if file_extension == ".pdf":
        document_text = pdf_parse(files[0])
        post_processed_text = f"""
The following is the content of a document:

{document_text}

Based on this document, answer the following question:

{text}
"""

        post_processed_text = wrap_query_context(text, post_processed_text)

        # text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
        for i in range(num_sides):
            states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
            states[i].conv.append_message(states[i].conv.roles[1], None)
            states[i].skip_next = False
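        # Net effect of this branch: both anonymous models receive the parsed
        # document plus the user's question as a single first-turn message,
        # with an empty assistant slot appended for each side's reply.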

    return (
        states
        + [x.to_gradio_chatbot() for x in states]
@@ -471,10 +566,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
        )

        multimodal_textbox = gr.MultimodalTextbox(
-            file_types=["image"],
+            file_types=["file"],
            show_label=False,
            container=True,
-            placeholder="Enter your prompt or add image here",
+            placeholder="Enter your prompt here. You can also upload an image or a PDF file",
            elem_id="input_box",
            scale=3,
        )
@@ -483,11 +578,12 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
        )

    with gr.Row() as button_row:
-        if random_questions:
-            global vqa_samples
-            with open(random_questions, "r") as f:
-                vqa_samples = json.load(f)
-            random_btn = gr.Button(value="🔮 Random Image", interactive=True)
+        random_btn = gr.Button(value="🔮 Random Image", interactive=True)
+        # if random_questions:
+        #     global vqa_samples
+        #     with open(random_questions, "r") as f:
+        #         vqa_samples = json.load(f)
+        #     random_btn = gr.Button(value="🔮 Random Image", interactive=True)
        clear_btn = gr.Button(value="🎲 New Round", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        share_btn = gr.Button(value="📷 Share")