Ollama embeddings
ashpreetbedi committed Feb 7, 2024
1 parent 0647dc9 commit 13c8db5
Showing 11 changed files with 366 additions and 139 deletions.
25 changes: 25 additions & 0 deletions cookbook/local_rag/README.md
@@ -0,0 +1,25 @@
## Local RAG Assistant

1. Run pgvector

```shell
phi start cookbook/local_rag/resources.py -y
```

2. Install libraries

```shell
pip install -U ollama pgvector pypdf psycopg sqlalchemy streamlit
```

3. Run Local RAG Assistant

```shell
python cookbook/local_rag/cli.py
```

4. Turn off pgvector

```shell
phi stop cookbook/local_rag/resources.py -y
```
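
These steps assume Ollama is installed and serving locally, with the models the app offers (`openhermes`, `llama2`) already pulled. If not, something like the following should work; the Streamlit app added in this commit is started the same way:

```shell
ollama pull openhermes
streamlit run cookbook/local_rag/app.py
```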
Empty file added cookbook/local_rag/__init__.py
161 changes: 161 additions & 0 deletions cookbook/local_rag/app.py
@@ -0,0 +1,161 @@
from typing import List

import streamlit as st
from phi.assistant import Assistant
from phi.document import Document
from phi.document.reader.pdf import PDFReader
from phi.tools.streamlit.components import (
    get_openai_key_sidebar,
    check_password,
    reload_button_sidebar,
    get_username_sidebar,
)

from assistant import get_local_rag_assistant  # type: ignore
from logging import getLogger

logger = getLogger(__name__)

st.set_page_config(
    page_title="Local RAG",
    page_icon=":orange_heart:",
)
st.title("Local RAG")
st.markdown("##### :orange_heart: built using [phidata](https://github.com/phidatahq/phidata)")


def restart_assistant():
    st.session_state["local_rag_assistant"] = None
    st.session_state["local_rag_assistant_run_id"] = None
    if "file_uploader_key" in st.session_state:
        st.session_state["file_uploader_key"] += 1
    st.rerun()


def main() -> None:
    # Get OpenAI key from environment variable or user input
    get_openai_key_sidebar()

    # Get username
    username = get_username_sidebar()
    if username:
        st.sidebar.info(f":technologist: User: {username}")
    else:
        st.write(":technologist: Please enter a username")
        return

    # Get model
    local_rag_model = st.sidebar.selectbox("Select Model", options=["openhermes", "llama2"])
    # Set assistant_type in session state
    if "local_rag_model" not in st.session_state:
        st.session_state["local_rag_model"] = local_rag_model
    # Restart the assistant if assistant_type has changed
    elif st.session_state["local_rag_model"] != local_rag_model:
        st.session_state["local_rag_model"] = local_rag_model
        restart_assistant()

    # Get the assistant
    local_rag_assistant: Assistant
    if "local_rag_assistant" not in st.session_state or st.session_state["local_rag_assistant"] is None:
        logger.info(f"---*--- Creating {local_rag_model} Assistant ---*---")
        local_rag_assistant = get_local_rag_assistant(
            model=local_rag_model,
            user_id=username,
            debug_mode=True,
        )
        st.session_state["local_rag_assistant"] = local_rag_assistant
    else:
        local_rag_assistant = st.session_state["local_rag_assistant"]

    # Create assistant run (i.e. log to database) and save run_id in session state
    st.session_state["local_rag_assistant_run_id"] = local_rag_assistant.create_run()

    # Load existing messages
    assistant_chat_history = local_rag_assistant.memory.get_chat_history()
    if len(assistant_chat_history) > 0:
        logger.debug("Loading chat history")
        st.session_state["messages"] = assistant_chat_history
    else:
        logger.debug("No chat history found")
        st.session_state["messages"] = [{"role": "assistant", "content": "Ask me anything..."}]

    # Prompt for user input
    if prompt := st.chat_input():
        st.session_state["messages"].append({"role": "user", "content": prompt})

    # Display existing chat messages
    for message in st.session_state["messages"]:
        if message["role"] == "system":
            continue
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # If last message is from a user, generate a new response
    last_message = st.session_state["messages"][-1]
    if last_message.get("role") == "user":
        question = last_message["content"]
        with st.chat_message("assistant"):
            response = ""
            resp_container = st.empty()
            for delta in local_rag_assistant.run(question):
                response += delta  # type: ignore
                resp_container.markdown(response)

            st.session_state["messages"].append({"role": "assistant", "content": response})

    if st.sidebar.button("New Run"):
        restart_assistant()

    if local_rag_assistant.knowledge_base:
        if st.sidebar.button("Clear Knowledge Base"):
            local_rag_assistant.knowledge_base.vector_db.clear()
            st.session_state["local_rag_knowledge_base_loaded"] = False
            st.sidebar.success("Knowledge base cleared")

    if st.sidebar.button("Auto Rename"):
        local_rag_assistant.auto_rename_run()

    # Upload PDF
    if local_rag_assistant.knowledge_base:
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = 0

        uploaded_file = st.sidebar.file_uploader(
            "Upload PDF",
            type="pdf",
            key=st.session_state["file_uploader_key"],
        )
        if uploaded_file is not None:
            alert = st.sidebar.info("Processing PDF...", icon="🧠")
            local_rag_name = uploaded_file.name.split(".")[0]
            if f"{local_rag_name}_uploaded" not in st.session_state:
                reader = PDFReader()
                local_rag_documents: List[Document] = reader.read(uploaded_file)
                if local_rag_documents:
                    local_rag_assistant.knowledge_base.load_documents(local_rag_documents)
                else:
                    st.sidebar.error("Could not read PDF")
                st.session_state[f"{local_rag_name}_uploaded"] = True
            alert.empty()

    if local_rag_assistant.storage:
        local_rag_assistant_run_ids: List[str] = local_rag_assistant.storage.get_all_run_ids(user_id=username)
        new_local_rag_assistant_run_id = st.sidebar.selectbox("Run ID", options=local_rag_assistant_run_ids)
        if st.session_state["local_rag_assistant_run_id"] != new_local_rag_assistant_run_id:
            logger.info(f"---*--- Loading {local_rag_model} run: {new_local_rag_assistant_run_id} ---*---")
            st.session_state["local_rag_assistant"] = get_local_rag_assistant(
                model=local_rag_model,
                user_id=username,
                run_id=new_local_rag_assistant_run_id,
                debug_mode=True,
            )
            st.rerun()

    local_rag_assistant_run_name = local_rag_assistant.run_name
    if local_rag_assistant_run_name:
        st.sidebar.write(f":thread: {local_rag_assistant_run_name}")

    # Show reload button
    reload_button_sidebar()


if check_password():
    main()
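
One idiom in `app.py` worth noting: `restart_assistant` clears the uploader by bumping `file_uploader_key`, because Streamlit treats a widget with a new `key` as a brand-new widget and discards its state. A minimal standalone sketch of the pattern (names here are illustrative):

```python
import streamlit as st

if "file_uploader_key" not in st.session_state:
    st.session_state["file_uploader_key"] = 0

# The uploader's identity is tied to its key ...
uploaded = st.file_uploader("Upload PDF", type="pdf", key=st.session_state["file_uploader_key"])

if st.button("Reset"):
    # ... so incrementing the key makes Streamlit discard the previous upload
    st.session_state["file_uploader_key"] += 1
    st.rerun()
```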
58 changes: 58 additions & 0 deletions cookbook/local_rag/assistant.py
@@ -0,0 +1,58 @@
from typing import Optional

from phi.assistant import Assistant
from phi.knowledge import AssistantKnowledge
from phi.llm.ollama import Ollama
from phi.embedder.ollama import OllamaEmbedder
from phi.vectordb.pgvector import PgVector2
from phi.storage.assistant.postgres import PgAssistantStorage

from resources import vector_db  # type: ignore

local_assistant_storage = PgAssistantStorage(
    db_url=vector_db.get_db_connection_local(),
    # Store assistant runs in table: ai.local_rag_assistant
    table_name="local_rag_assistant",
)


def get_knowledge_base_for_model(model: str) -> AssistantKnowledge:
    return AssistantKnowledge(
        vector_db=PgVector2(
            db_url=vector_db.get_db_connection_local(),
            # Store embeddings in table: ai.local_{model}_documents
            collection=f"local_{model}_documents",
            # Use the OllamaEmbedder to get embeddings
            embedder=OllamaEmbedder(model=model),
        ),
        # 5 references are added to the prompt
        num_documents=5,
    )


def get_local_rag_assistant(
    model: str = "openhermes",
    user_id: Optional[str] = None,
    run_id: Optional[str] = None,
    debug_mode: bool = False,
) -> Assistant:
    """Get a Local RAG Assistant."""

    return Assistant(
        name="local_rag_assistant",
        run_id=run_id,
        user_id=user_id,
        llm=Ollama(model=model),
        storage=local_assistant_storage,
        knowledge_base=get_knowledge_base_for_model(model),
        # This setting adds references from the knowledge_base to the user prompt
        add_references_to_prompt=True,
        # This setting adds the last 4 messages from the chat history to the API call
        add_chat_history_to_messages=True,
        num_history_messages=4,
        # This setting tells the LLM to format messages in markdown
        markdown=True,
        debug_mode=debug_mode,
        description="You are an AI called 'Phi' designed to help users answer questions from a knowledge base of PDFs.",
        assistant_data={"assistant_type": "rag"},
    )
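
A hedged usage sketch of this factory, run from `cookbook/local_rag/` with pgvector up and `openhermes` pulled; `create_run` and the streaming `run` call mirror how `app.py` consumes it:

```python
from assistant import get_local_rag_assistant

assistant = get_local_rag_assistant(model="openhermes", user_id="demo", debug_mode=True)
assistant.create_run()  # logs the run to the ai.local_rag_assistant table

# run() yields the response as a stream of text deltas
for delta in assistant.run("Summarize the documents in the knowledge base"):
    print(delta, end="")
```

Keeping one collection per model (`local_{model}_documents`) is deliberate: different Ollama models emit embedding vectors of different dimensions, so documents embedded with one model cannot be searched with another model's embedder.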
28 changes: 28 additions & 0 deletions cookbook/local_rag/cli.py
@@ -0,0 +1,28 @@
from phi.assistant import Assistant
from phi.llm.ollama import Ollama
from phi.vectordb.pgvector import PgVector2
from phi.embedder.ollama import OllamaEmbedder
from phi.knowledge.pdf import PDFUrlKnowledgeBase

from resources import vector_db  # type: ignore

model = "openhermes"
knowledge_base = PDFUrlKnowledgeBase(
    urls=["https://phi-public.s3.amazonaws.com/recipes/ThaiRecipes.pdf"],
    vector_db=PgVector2(
        collection="recipes", db_url=vector_db.get_db_connection_local(), embedder=OllamaEmbedder(model=model)
    ),
)
# Load the knowledge base (comment out after the first run)
knowledge_base.load(recreate=False)

assistant = Assistant(
    llm=Ollama(model=model),
    knowledge_base=knowledge_base,
    add_references_to_prompt=True,
    debug_mode=True,
)

assistant.print_response("Got any pad thai?")

# Use this to run a CLI application with multi-turn conversation
# assistant.cli_app()
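
With pgvector running, the script runs directly; the first run embeds the PDF through Ollama, so expect it to take a moment:

```shell
python cookbook/local_rag/cli.py
```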
13 changes: 13 additions & 0 deletions cookbook/local_rag/resources.py
@@ -0,0 +1,13 @@
from phi.docker.app.postgres import PgVectorDb
from phi.docker.resources import DockerResources

# -*- PgVector running in Docker, container port 5432 mapped to host port 5432
vector_db = PgVectorDb(
    pg_user="ai",
    pg_password="ai",
    pg_database="ai",
    debug_mode=True,
)

# -*- DockerResources
dev_docker_resources = DockerResources(apps=[vector_db])
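
For reference, with the settings above `vector_db.get_db_connection_local()` should resolve to a URL along these lines (assumed defaults: database `ai`, host port 5432, psycopg driver):

```python
# Assumed shape of the local connection URL used by assistant.py and cli.py
db_url = "postgresql+psycopg://ai:ai@localhost:5432/ai"
```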
4 changes: 4 additions & 0 deletions phi/assistant/assistant.py
@@ -745,8 +745,12 @@ def print_response(
         from rich.box import ROUNDED
         from rich.markdown import Markdown
 
+        if markdown:
+            self.markdown = True
+
         if self.output_model is not None:
             markdown = False
+            self.markdown = False
             stream = False
 
         if stream:
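
The hunk above changes two behaviors: a one-off `print_response(..., markdown=True)` now persists the preference on the assistant via `self.markdown`, and a structured `output_model` now force-disables markdown (in addition to streaming) so the output stays parseable. A sketch of the resulting call-site behavior; the `MovieScript` model is hypothetical:

```python
from pydantic import BaseModel

from phi.assistant import Assistant
from phi.llm.ollama import Ollama


class MovieScript(BaseModel):
    # Hypothetical structured-output model, purely for illustration
    title: str
    genre: str


assistant = Assistant(llm=Ollama(model="openhermes"))

# markdown=True now also sets assistant.markdown, so later calls render markdown too
assistant.print_response("Got any pad thai?", markdown=True)

# With an output_model set, print_response forces markdown and streaming off
assistant.output_model = MovieScript
assistant.print_response("Write a one-line movie script about local LLMs")
```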