-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0647dc9
commit 13c8db5
Showing
11 changed files
with
366 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
## RAG Assistant

1. Run pgvector

```shell
phi start cookbook/rag/resources.py -y
```

2. Install libraries

```shell
pip install -U pgvector pypdf psycopg sqlalchemy
```

3. Run RAG Assistant

```shell
python cookbook/rag/assistant.py
```

4. Turn off pgvector

```shell
phi stop cookbook/rag/resources.py -y
```
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from typing import List | ||
|
||
import streamlit as st | ||
from phi.assistant import Assistant | ||
from phi.document import Document | ||
from phi.document.reader.pdf import PDFReader | ||
from phi.tools.streamlit.components import ( | ||
get_openai_key_sidebar, | ||
check_password, | ||
reload_button_sidebar, | ||
get_username_sidebar, | ||
) | ||
|
||
from assistant import get_local_rag_assistant # type: ignore | ||
from logging import getLogger | ||
|
||
logger = getLogger(__name__) | ||
|
||
# Configure the Streamlit page before any other UI calls (Streamlit requires
# set_page_config to be the first st.* command executed).
st.set_page_config(
    page_title="Local RAG",
    page_icon=":orange_heart:",
)
# Page header and attribution link.
st.title("Local RAG")
st.markdown("##### :orange_heart: built using [phidata](https://github.com/phidatahq/phidata)")
|
||
|
||
def restart_assistant() -> None:
    """Drop the current assistant from session state and rerun the app.

    Clears the cached assistant and its run id, invalidates the PDF
    uploader widget by bumping its key, and triggers an immediate rerun
    so the UI rebuilds from scratch.
    """
    st.session_state["local_rag_assistant"] = None
    st.session_state["local_rag_assistant_run_id"] = None
    # Bump the uploader key so Streamlit creates a fresh file_uploader widget.
    # Use .get() with a default: a restart can happen (e.g. on model change)
    # before the uploader key was ever initialized, and `+= 1` on a missing
    # session-state key raises KeyError.
    st.session_state["file_uploader_key"] = st.session_state.get("file_uploader_key", 0) + 1
    st.rerun()
|
||
|
||
def main() -> None:
    """Render the Local RAG chat page: sidebar controls, chat loop, PDF upload.

    Runs top-to-bottom on every Streamlit rerun. Reads and writes
    st.session_state keys: local_rag_model, local_rag_assistant,
    local_rag_assistant_run_id, messages, file_uploader_key, and
    per-file "<name>_uploaded" flags.
    """
    # Get OpenAI key from environment variable or user input
    get_openai_key_sidebar()

    # Get username; the page renders nothing else until one is entered.
    username = get_username_sidebar()
    if username:
        st.sidebar.info(f":technologist: User: {username}")
    else:
        st.write(":technologist: Please enter a username")
        return

    # Get model
    local_rag_model = st.sidebar.selectbox("Select Model", options=["openhermes", "llama2"])
    # Set assistant_type in session state
    if "local_rag_model" not in st.session_state:
        st.session_state["local_rag_model"] = local_rag_model
    # Restart the assistant if assistant_type has changed
    elif st.session_state["local_rag_model"] != local_rag_model:
        st.session_state["local_rag_model"] = local_rag_model
        restart_assistant()

    # Get the assistant: reuse the cached one, or build a new one for the
    # selected model.
    local_rag_assistant: Assistant
    if "local_rag_assistant" not in st.session_state or st.session_state["local_rag_assistant"] is None:
        logger.info(f"---*--- Creating {local_rag_model} Assistant ---*---")
        local_rag_assistant = get_local_rag_assistant(
            model=local_rag_model,
            user_id=username,
            debug_mode=True,
        )
        st.session_state["local_rag_assistant"] = local_rag_assistant
    else:
        local_rag_assistant = st.session_state["local_rag_assistant"]

    # Create assistant run (i.e. log to database) and save run_id in session state
    # NOTE(review): this executes on every rerun — presumably create_run() is
    # idempotent for an assistant that already has a run; confirm in phidata.
    st.session_state["local_rag_assistant_run_id"] = local_rag_assistant.create_run()

    # Load existing messages from the assistant's memory, or seed a greeting.
    assistant_chat_history = local_rag_assistant.memory.get_chat_history()
    if len(assistant_chat_history) > 0:
        logger.debug("Loading chat history")
        st.session_state["messages"] = assistant_chat_history
    else:
        logger.debug("No chat history found")
        st.session_state["messages"] = [{"role": "assistant", "content": "Ask me anything..."}]

    # Prompt for user input
    if prompt := st.chat_input():
        st.session_state["messages"].append({"role": "user", "content": prompt})

    # Display existing chat messages (system messages are hidden from the UI).
    for message in st.session_state["messages"]:
        if message["role"] == "system":
            continue
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # If last message is from a user, generate a new response
    last_message = st.session_state["messages"][-1]
    if last_message.get("role") == "user":
        question = last_message["content"]
        with st.chat_message("assistant"):
            response = ""
            resp_container = st.empty()
            # Stream the answer token-by-token into a single placeholder.
            for delta in local_rag_assistant.run(question):
                response += delta  # type: ignore
                resp_container.markdown(response)
            st.session_state["messages"].append({"role": "assistant", "content": response})

    if st.sidebar.button("New Run"):
        restart_assistant()

    if local_rag_assistant.knowledge_base:
        if st.sidebar.button("Clear Knowledge Base"):
            local_rag_assistant.knowledge_base.vector_db.clear()
            st.session_state["local_rag_knowledge_base_loaded"] = False
            st.sidebar.success("Knowledge base cleared")

    if st.sidebar.button("Auto Rename"):
        local_rag_assistant.auto_rename_run()

    # Upload PDF into the knowledge base (only when a knowledge base exists).
    if local_rag_assistant.knowledge_base:
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = 0

        uploaded_file = st.sidebar.file_uploader(
            "Upload PDF",
            type="pdf",
            # Keyed by a counter so restart_assistant() can reset the widget.
            key=st.session_state["file_uploader_key"],
        )
        if uploaded_file is not None:
            alert = st.sidebar.info("Processing PDF...", icon="🧠")
            local_rag_name = uploaded_file.name.split(".")[0]
            # Per-file session flag avoids re-embedding the same PDF on reruns.
            if f"{local_rag_name}_uploaded" not in st.session_state:
                reader = PDFReader()
                local_rag_documents: List[Document] = reader.read(uploaded_file)
                if local_rag_documents:
                    local_rag_assistant.knowledge_base.load_documents(local_rag_documents)
                else:
                    st.sidebar.error("Could not read PDF")
                st.session_state[f"{local_rag_name}_uploaded"] = True
            alert.empty()

    # Run selector: switch to a previously stored run for this user.
    if local_rag_assistant.storage:
        local_rag_assistant_run_ids: List[str] = local_rag_assistant.storage.get_all_run_ids(user_id=username)
        new_local_rag_assistant_run_id = st.sidebar.selectbox("Run ID", options=local_rag_assistant_run_ids)
        if st.session_state["local_rag_assistant_run_id"] != new_local_rag_assistant_run_id:
            logger.info(f"---*--- Loading {local_rag_model} run: {new_local_rag_assistant_run_id} ---*---")
            # NOTE(review): the selected run id is never passed to
            # get_local_rag_assistant (which accepts run_id=), so this appears
            # to build a fresh assistant instead of loading the chosen run —
            # likely should pass run_id=new_local_rag_assistant_run_id. Verify.
            st.session_state["local_rag_assistant"] = get_local_rag_assistant(
                model=local_rag_model,
                user_id=username,
                debug_mode=True,
            )
            st.rerun()

    local_rag_assistant_run_name = local_rag_assistant.run_name
    if local_rag_assistant_run_name:
        st.sidebar.write(f":thread: {local_rag_assistant_run_name}")

    # Show reload button
    reload_button_sidebar()
|
||
|
||
# Gate the whole app behind the shared-password check; main() renders nothing
# until the password is accepted.
if check_password():
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Optional | ||
|
||
from phi.assistant import Assistant | ||
from phi.knowledge import AssistantKnowledge | ||
from phi.llm.ollama import Ollama | ||
from phi.embedder.ollama import OllamaEmbedder | ||
from phi.vectordb.pgvector import PgVector2 | ||
from phi.storage.assistant.postgres import PgAssistantStorage | ||
|
||
from resources import vector_db # type: ignore | ||
|
||
# Module-level storage shared by every assistant built in this module:
# persists assistant runs (chat history + metadata) in local Postgres.
local_assistant_storage = PgAssistantStorage(
    db_url=vector_db.get_db_connection_local(),
    # Store assistant runs in table: ai.local_rag_assistant
    table_name="local_rag_assistant",
)
|
||
|
||
def get_knowledge_base_for_model(model: str) -> AssistantKnowledge:
    """Build a pgvector-backed knowledge base for the given Ollama model.

    Embeddings are computed locally with the same Ollama model and stored
    in a per-model table: ai.local_{model}_documents.
    """
    document_store = PgVector2(
        db_url=vector_db.get_db_connection_local(),
        # One collection per model so embeddings from different models never mix
        collection=f"local_{model}_documents",
        # Use the OllamaEmbedder to get embeddings
        embedder=OllamaEmbedder(model=model),
    )
    # The top 5 matching documents are added to the prompt
    return AssistantKnowledge(vector_db=document_store, num_documents=5)
|
||
|
||
def get_local_rag_assistant(
    model: str = "openhermes",
    user_id: Optional[str] = None,
    run_id: Optional[str] = None,
    debug_mode: bool = False,
) -> Assistant:
    """Get a Local RAG Assistant.

    Wires a local Ollama LLM to a per-model pgvector knowledge base and
    the shared Postgres run storage. Pass run_id to resume a stored run.
    """
    llm = Ollama(model=model)
    knowledge_base = get_knowledge_base_for_model(model)

    return Assistant(
        name="local_rag_assistant",
        run_id=run_id,
        user_id=user_id,
        llm=llm,
        storage=local_assistant_storage,
        knowledge_base=knowledge_base,
        # This setting adds references from the knowledge_base to the user prompt
        add_references_to_prompt=True,
        # This setting adds the last 4 messages from the chat history to the API call
        add_chat_history_to_messages=True,
        num_history_messages=4,
        # This setting tells the LLM to format messages in markdown
        markdown=True,
        debug_mode=debug_mode,
        description="You are a AI called 'Phi' designed to help users answer questions from a knowledge base of PDFs.",
        assistant_data={"assistant_type": "rag"},
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from phi.assistant import Assistant | ||
from phi.llm.ollama import Ollama | ||
from phi.vectordb.pgvector import PgVector2 | ||
from phi.embedder.ollama import OllamaEmbedder | ||
from phi.knowledge.pdf import PDFUrlKnowledgeBase | ||
|
||
from resources import vector_db # type: ignore | ||
|
||
# Demo script: RAG over a recipes PDF using a local Ollama model + pgvector.
model = "openhermes"
knowledge_base = PDFUrlKnowledgeBase(
    urls=["https://phi-public.s3.amazonaws.com/recipes/ThaiRecipes.pdf"],
    vector_db=PgVector2(
        collection="recipes", db_url=vector_db.get_db_connection_local(), embedder=OllamaEmbedder(model=model)
    ),
)
# First run only: uncomment to download, embed, and index the PDF.
# knowledge_base.load(recreate=False)

assistant = Assistant(
    llm=Ollama(model=model),
    knowledge_base=knowledge_base,
    # Inject matching knowledge-base references into the user prompt
    add_references_to_prompt=True,
    debug_mode=True,
)

# Single-shot question answered from the indexed PDF.
assistant.print_response("Got any pad thai?")

# Use this to run a CLI application with multi-turn conversation
# assistant.cli_app()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from phi.docker.app.postgres import PgVectorDb | ||
from phi.docker.resources import DockerResources | ||
|
||
# -*- PgVector running on port 5432:5432
# NOTE: dev-only credentials, hard-coded for the local Docker workflow —
# do not reuse outside local development.
vector_db = PgVectorDb(
    pg_user="ai",
    pg_password="ai",
    pg_database="ai",
    debug_mode=True,
)

# -*- DockerResources: the set of apps `phi start/stop` manages for this file.
dev_docker_resources = DockerResources(apps=[vector_db])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.